forked from mindspore-Ecosystem/mindspore

commit f5dbafc1eb (parent 0f143d65b2): optimized the code error messages
@@ -621,13 +621,14 @@ class Validator:
         return arg_type
 
     @staticmethod
-    def check_reduce_shape(ori_shape, shape, axis, prim_name):
+    def check_reduce_shape(ori_shape, shape, axis, prim_name, arg_name1, arg_name2):
         """Checks whether shape is ori_shape reduced on axis"""
+        axis_origin = axis
         axis = axis if isinstance(axis, Iterable) else (axis,)
         exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
         if list(shape) != exp_shape:
-            raise ValueError(f"For '{prim_name}', the 'ori_shape' {ori_shape} reduce on 'axis' {axis} should be "
-                             f"{tuple(exp_shape)}, but got 'shape': {shape}.")
+            raise ValueError(f"For '{prim_name}', the '{arg_name1}'.shape reduce on 'axis': {axis_origin} should "
+                             f"be equal to '{arg_name2}'.shape: {shape}, but got {ori_shape}.")
 
     @staticmethod
     def check_astype_dtype(dtype):
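This first hunk is representative of the whole commit: messages now name the user-visible arguments (arg_name1/arg_name2) and report the axis as the user passed it (axis_origin) instead of internal variable names. Below is a minimal standalone sketch of the updated check; the demo call at the end is illustrative and not part of the commit:

```python
from collections.abc import Iterable


def check_reduce_shape(ori_shape, shape, axis, prim_name, arg_name1, arg_name2):
    """Check that `shape` equals `ori_shape` with `axis` reduced away."""
    axis_origin = axis
    axis = axis if isinstance(axis, Iterable) else (axis,)
    exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
    if list(shape) != exp_shape:
        # The message names the offending arguments rather than the
        # internal variables 'ori_shape' and 'shape'.
        raise ValueError(f"For '{prim_name}', the '{arg_name1}'.shape reduce on 'axis': {axis_origin} should "
                         f"be equal to '{arg_name2}'.shape: {shape}, but got {ori_shape}.")


# Reducing (4, 3, 2) on axis 1 yields (4, 2), so this passes; (4, 3) would raise.
check_reduce_shape((4, 3, 2), (4, 2), 1, 'Demo', 'logits_x1', 'labels')
```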
@@ -36,8 +36,7 @@ def _valid_cell(cell, op_name=None):
     if issubclass(cell.__class__, Cell):
         return True
     msg_prefix = f"For '{op_name}'," if op_name else ""
-    raise TypeError(f'{msg_prefix} each cell should be subclass of Cell. '
-                    f'Please check your code')
+    raise TypeError(f'{msg_prefix} each cell should be subclass of Cell, but got {type(cell).__name__}.')
 
 
 def _get_prefix_and_index(cells):
@@ -382,7 +381,7 @@ class CellList(_CellListBase, Cell):
         cls_name = self.__class__.__name__
         if not isinstance(cells, list):
            raise TypeError(f"For '{cls_name}', the new cells wanted to append "
-                            f"should be instance of list.")
+                            f"should be instance of list, but got {type(cells).__name__}.")
         prefix, _ = _get_prefix_and_index(self._cells)
         for cell in cells:
             if _valid_cell(cell, cls_name):
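The net effect of these two container hunks is that a bad element now shows up by type in the error text instead of a bare "Please check your code". A minimal mimic, with Cell stubbed purely for illustration:

```python
class Cell:                      # stand-in for mindspore.nn.Cell
    pass


def _valid_cell(cell, op_name=None):
    if issubclass(cell.__class__, Cell):
        return True
    msg_prefix = f"For '{op_name}'," if op_name else ""
    raise TypeError(f'{msg_prefix} each cell should be subclass of Cell, but got {type(cell).__name__}.')


try:
    _valid_cell(42, 'CellList')
except TypeError as e:
    print(e)   # For 'CellList', each cell should be subclass of Cell, but got int.
```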
@@ -1110,9 +1110,9 @@ class BCELoss(LossBase):
 
 
 @constexpr
-def _check_reduced_shape_valid(ori_shape, reduced_shape, axis, cls_name):
+def _check_reduced_shape_valid(ori_shape, reduced_shape, axis, cls_name, arg_name1, arg_name2):
     """Internal function, used to check whether the reduced shape meets the requirements."""
-    validator.check_reduce_shape(ori_shape, reduced_shape, axis, cls_name)
+    validator.check_reduce_shape(ori_shape, reduced_shape, axis, cls_name, arg_name1, arg_name2)
 
 
 class CosineEmbeddingLoss(LossBase):
@@ -1173,7 +1173,7 @@ class CosineEmbeddingLoss(LossBase):
         _check_is_tensor('logits_x2', logits_x2, self.cls_name)
         _check_is_tensor('labels', labels, self.cls_name)
         F.same_type_shape(logits_x1, logits_x2)
-        _check_reduced_shape_valid(F.shape(logits_x1), F.shape(labels), (1,), self.cls_name)
+        _check_reduced_shape_valid(F.shape(logits_x1), F.shape(labels), (1,), self.cls_name, "logits_x1", "labels")
         # if labels > 0, 1-cosine(logits_x1, logits_x2)
         # else, max(0, cosine(logits_x1, logits_x2)-margin)
         prod_sum = self.reduce_sum(logits_x1 * logits_x2, (1,))
@@ -1285,13 +1285,13 @@ def _check_ndim(logits_nidm, labels_ndim, prime_name=None):
     '''Internal function, used to check whether the dimension of logits and labels meets the requirements.'''
     msg_prefix = f'For \'{prime_name}\', the' if prime_name else "The"
     if logits_nidm < 2 or logits_nidm > 4:
-        raise ValueError(f"{msg_prefix} dimensions of 'logits' should be in [2, 4], but got"
+        raise ValueError(f"{msg_prefix} dimensions of 'logits' should be in [2, 4], but got "
                          f"dimension of 'logits' {logits_nidm}.")
     if labels_ndim < 2 or labels_ndim > 4:
-        raise ValueError(f"{msg_prefix} dimensions of 'labels' should be in [2, 4], but got"
+        raise ValueError(f"{msg_prefix} dimensions of 'labels' should be in [2, 4], but got "
                          f"dimension of 'labels' {labels_ndim}.")
     if logits_nidm != labels_ndim:
-        raise ValueError(f"{msg_prefix} dimensions of 'logits' and 'labels' must be equal, but got"
+        raise ValueError(f"{msg_prefix} dimensions of 'logits' and 'labels' must be equal, but got "
                          f"dimension of 'logits' {logits_nidm} and dimension of 'labels' {labels_ndim}.")
 
 @constexpr
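The only change in this hunk is a trailing space, but it fixes a real defect: adjacent string literals concatenate with nothing between them, so without the space the rendered message ran two words together. A two-line demonstration:

```python
# Implicit concatenation of adjacent literals inserts no separator.
broken = (f"dimensions of 'logits' should be in [2, 4], but got"
          f"dimension of 'logits' 5.")
fixed = (f"dimensions of 'logits' should be in [2, 4], but got "
         f"dimension of 'logits' 5.")
print(broken)  # ...but gotdimension of 'logits' 5.
print(fixed)   # ...but got dimension of 'logits' 5.
```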
@@ -1299,12 +1299,10 @@ def _check_channel_and_shape(logits, labels, prime_name=None):
     '''Internal function, used to check whether the channels or shape of logits and labels meets the requirements.'''
     msg_prefix = f'For \'{prime_name}\', the' if prime_name else "The"
     if logits == 1:
-        raise ValueError(f"{msg_prefix} single channel prediction is not supported, but got {logits}.")
+        raise ValueError(f"{msg_prefix} 'logits'.shape[1] cannot be one, but got {logits}.")
     if labels not in (1, logits):
-        raise ValueError(f"{msg_prefix} channel of 'labels' must be one or the 'labels' must be the same as that of "
-                         f"the 'logits'. If there is only one channel, its value should be in the range [0, C-1], "
-                         f"where C is the number of classes "
-                         f"inferred from 'logits': C={logits}, but got 'labels': {labels}.")
+        raise ValueError(f"{msg_prefix} 'labels'.shape[1] must be one or equal to 'logits'.shape[1]: {logits}, "
+                         f"but got {labels}.")
 
 
 @constexpr
@@ -22,7 +22,7 @@ from ...common import dtype as mstype
 from ..primitive import constexpr
 
 
-def get_broadcast_shape(x_shape, y_shape, prim_name, shape_type=""):
+def get_broadcast_shape(x_shape, y_shape, prim_name, shape_type="", arg_name1="x", arg_name2="y"):
     """
     Doing broadcast between tensor x and tensor y.
 
@@ -30,6 +30,9 @@ def get_broadcast_shape(x_shape, y_shape, prim_name, shape_type=""):
         x_shape (list): The shape of tensor x.
         y_shape (list): The shape of tensor y.
         prim_name (str): Primitive name.
+        shape_type (str): The type of shape, optional values are "", "min_shape" and "max_shape".
+        arg_name1 (str): The arg name of x_shape.
+        arg_name2 (str): The arg name of y_shape.
 
     Returns:
         List, the shape that broadcast between tensor x and tensor y.
@@ -64,11 +67,12 @@ def get_broadcast_shape(x_shape, y_shape, prim_name, shape_type=""):
             elif shape_type == "max_shape":
                 broadcast_shape_back.append(min(x_shape[i], y_shape[i]))
             else:
-                raise ValueError(f"For '{prim_name}', 'x_shape' and 'y_shape' are supposed to broadcast, "
-                                 f"where broadcast means that "
-                                 f"'x_shape[i] = 1 or -1 or y_shape[i] = 1 or -1 or x_shape[i] = y_shape[i]', "
-                                 f"but now 'x_shape' and 'y_shape' can not broadcast, "
-                                 f"got 'i': {i}, 'x_shape': {x_shape}, 'y_shape': {y_shape}.")
+                raise ValueError(f"For '{prim_name}', '{arg_name1}'.shape and '{arg_name2}'.shape are supposed "
+                                 f"to broadcast, where broadcast means that '{arg_name1}'.shape[i] = 1 or -1 "
+                                 f"or '{arg_name2}'.shape[i] = 1 or -1 "
+                                 f"or '{arg_name1}'.shape[i] = '{arg_name2}'.shape[i]', "
+                                 f"but now '{arg_name1}'.shape and '{arg_name2}'.shape can not broadcast, "
+                                 f"got 'i': {i}, '{arg_name1}'.shape: {x_shape}, '{arg_name2}'.shape: {y_shape}.")
 
     broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
     broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
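For reference, a minimal standalone sketch of the broadcast rule the new message describes: shapes are aligned from the right, and each aligned pair must match or contain a 1 (or -1, a dynamic dimension). This simplification drops the shape_type handling of the real function:

```python
def get_broadcast_shape(x_shape, y_shape, prim_name, arg_name1="x", arg_name2="y"):
    """Right-aligned broadcast: each pair must be equal, or one side 1 or -1."""
    x_len, y_len = len(x_shape), len(y_shape)
    length = min(x_len, y_len)
    broadcast_shape_back = []
    for i in range(-length, 0):
        if x_shape[i] in (1, -1):
            broadcast_shape_back.append(y_shape[i])
        elif y_shape[i] in (1, -1) or x_shape[i] == y_shape[i]:
            broadcast_shape_back.append(x_shape[i])
        else:
            raise ValueError(f"For '{prim_name}', '{arg_name1}'.shape and '{arg_name2}'.shape "
                             f"can not broadcast, got 'i': {i}, '{arg_name1}'.shape: {x_shape}, "
                             f"'{arg_name2}'.shape: {y_shape}.")
    front = y_shape[:y_len - length] if length == x_len else x_shape[:x_len - length]
    return list(front) + broadcast_shape_back


print(get_broadcast_shape([8, 1, 3], [4, 3], 'Demo'))   # [8, 4, 3]
```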
@@ -1631,11 +1631,11 @@ class InvertPermutation(PrimitiveWithInfer):
     def __infer__(self, x):
         x_shp = x['shape']
         x_value = x['value']
-        if mstype.issubclass_(x['dtype'], mstype.tensor):
-            raise ValueError(f"For \'{self.name}\', the value of 'input_x' must be non-Tensor, but got {x['dtype']}")
         if x_value is None:
             raise ValueError(f"For '{self.name}', the value of 'input_x' can not be None, but got {x_value}.")
         validator.check_value_type("shape", x_shp, [tuple, list], self.name)
+        if mstype.issubclass_(x['dtype'], mstype.tensor):
+            raise ValueError(f"For \'{self.name}\', the value of 'input_x' must be non-Tensor, but got {x['dtype']}")
         for shp in x_shp:
             if shp:
                 x_rank = len(np.array(x_value, np.int64).shape)
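This hunk only reorders the guards: the value is tested for None before its dtype is examined, so the later code such as np.array(x_value, np.int64) can always assume a concrete value. A sketch of the resulting guard order, with the dtype test stubbed for illustration:

```python
import numpy as np


def infer_invert_permutation(x, name='InvertPermutation'):
    x_shp, x_value, x_is_tensor = x['shape'], x['value'], x['is_tensor']
    if x_value is None:
        raise ValueError(f"For '{name}', the value of 'input_x' can not be None, but got {x_value}.")
    if not isinstance(x_shp, (tuple, list)):
        raise TypeError(f"For '{name}', 'shape' must be a tuple or list.")
    if x_is_tensor:      # stand-in for mstype.issubclass_(x['dtype'], mstype.tensor)
        raise ValueError(f"For '{name}', the value of 'input_x' must be non-Tensor.")
    return len(np.array(x_value, np.int64).shape)   # safe: x_value is not None here


print(infer_invert_permutation({'shape': (3,), 'value': (2, 0, 1), 'is_tensor': False}))  # 1
```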
@@ -3416,11 +3416,11 @@ class StridedSlice(PrimitiveWithInfer):
                 continue
             if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
                 if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
-                    raise IndexError(f"For '{self.name}', the 'strides' cannot be negative number and "
-                                     f"'begin' should be in [-{x_shape[i]}, {x_shape[i]}) "
+                    raise IndexError(f"For '{self.name}', the 'strides[{i}]' cannot be negative number and "
+                                     f"'begin[{i}]' should be in [-{x_shape[i]}, {x_shape[i]}) "
                                      f"when 'shrink_axis_mask' is greater than 0, "
-                                     f"but got 'shrink_axis_mask': {self.shrink_axis_mask}, 'strides': {stride}, "
-                                     f"'begin': {begin}.")
+                                     f"but got 'shrink_axis_mask': {self.shrink_axis_mask}, "
+                                     f"'strides[{i}]': {stride}, 'begin[{i}]': {begin}.")
                 j += 1
                 i += 1
                 continue
@@ -3452,11 +3452,11 @@ class StridedSlice(PrimitiveWithInfer):
                 continue
             if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
                 if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
-                    raise IndexError(f"For '{self.name}', the 'strides' cannot be negative number and "
-                                     f"'begin' should be in [-{x_shape[i]}, {x_shape[i]}) "
+                    raise IndexError(f"For '{self.name}', the 'strides[{i}]' cannot be negative number and "
+                                     f"'begin[{i}]' should be in [-{x_shape[i]}, {x_shape[i]}) "
                                      f"when 'shrink_axis_mask' is greater than 0, "
-                                     f"but got 'shrink_axis_mask': {self.shrink_axis_mask}, 'strides': {stride}, "
-                                     f"'begin': {begin}.")
+                                     f"but got 'shrink_axis_mask': {self.shrink_axis_mask}, "
+                                     f"'strides[{i}]': {stride}, 'begin[{i}]': {begin}.")
                 j += 1
                 i += 1
                 continue
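The indexed form 'begin[{i}]' matters because StridedSlice validates one axis at a time: for a shrunk axis, begin[i] must index a real element, i.e. lie in [-x_shape[i], x_shape[i]), and strides[i] must be positive. A standalone sketch of that per-axis test (check_shrink_axis is a hypothetical helper, not the primitive itself):

```python
def check_shrink_axis(x_shape, begin, strides, shrink_axis_mask, name='StridedSlice'):
    """Validate begin/strides on every axis whose shrink bit is set."""
    for i, dim in enumerate(x_shape):
        if not (shrink_axis_mask >> i) & 1:
            continue
        if not -dim <= begin[i] < dim or strides[i] < 0:
            raise IndexError(f"For '{name}', the 'strides[{i}]' cannot be negative number and "
                             f"'begin[{i}]' should be in [-{dim}, {dim}) "
                             f"when 'shrink_axis_mask' is greater than 0, "
                             f"but got 'shrink_axis_mask': {shrink_axis_mask}, "
                             f"'strides[{i}]': {strides[i]}, 'begin[{i}]': {begin[i]}.")


check_shrink_axis((6, 7, 8), begin=(0, 1, 2), strides=(1, 1, 1), shrink_axis_mask=7)     # ok
# check_shrink_axis((6, 7, 8), begin=(-7, 0, 0), strides=(1, 1, 1), shrink_axis_mask=7)  # IndexError
```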
@@ -3593,8 +3593,8 @@ class DiagPart(PrimitiveWithInfer):
     def infer_shape(self, x_shape):
         if len(x_shape) % 2 != 0 or \
                 not x_shape:
-            raise ValueError(f"For \'{self.name}\', the rank of 'input_x' must be non-zero and even, "
-                             f"but got rank {len(x_shape)}, with shapes {x_shape}.")
+            raise ValueError(f"For \'{self.name}\', the dimension of 'input_x' must be non-zero and even, "
+                             f"but got dimension {len(x_shape)}, with shapes {x_shape}.")
         length = len(x_shape) // 2
         for i in range(length):
             validator.check('input_shape[i + len(input_shape)/2]', x_shape[i + length],
@@ -3966,7 +3966,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
             raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
                              f"the dimension of 'input_x', but got the "
                              f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
-                             f"{len(indices_shape)}.")
+                             f"{len(input_x_shape)}.")
 
         updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
         if updates_shape_check != updates_shape:
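Unlike most hunks here, this one (and the identical TensorScatterAdd/Max/Min/Sub hunks below) fixes a wrong value in the message, not just its wording: the text promised "the dimension of 'input_x'" but interpolated len(indices_shape). A simplified stand-in for the check (check_scatter_shapes is a hypothetical helper):

```python
def check_scatter_shapes(input_x_shape, indices_shape, name='TensorScatterUpdate'):
    """The last dim of indices must not exceed the rank of input_x."""
    if indices_shape[-1] > len(input_x_shape):
        raise ValueError(f"For '{name}', the last dimension of 'indices' must be less than or equal to "
                         f"the dimension of 'input_x', but got the "
                         f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
                         f"{len(input_x_shape)}.")  # the old code interpolated len(indices_shape) here


check_scatter_shapes((5, 4, 3), (7, 2))    # ok: 2 <= 3
# check_scatter_shapes((5, 4), (7, 6, 3))  # raises: reports 'input_x' rank 2 (old message claimed 3)
```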
@@ -4051,7 +4051,7 @@ class TensorScatterAdd(PrimitiveWithInfer):
             raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
                              f"the dimension of 'input_x', but got the "
                              f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
-                             f"{len(indices_shape)}.")
+                             f"{len(input_x_shape)}.")
 
         updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
         if updates_shape_check != updates_shape:
@@ -6039,7 +6039,8 @@ class EmbeddingLookup(PrimitiveWithCheck):
         validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
         indices_shp = indices['shape']
         if not indices_shp:
-            raise ValueError(f"For '{self.name}', the 'input_indices' should not be a scalar, but got {indices_shp}.")
+            raise ValueError(f"For '{self.name}', the dimension of 'input_indices' should not "
+                             f"be zero, but got {len(indices_shp)}.")
         params_shp = params['shape']
         if len(params_shp) > 2:
             raise ValueError(f"For '{self.name}', the dimension of 'input_params' must <= 2, "
@@ -6272,7 +6273,7 @@ class MaskedSelect(PrimitiveWithCheck):
         self.init_prim_io_names(inputs=['x', 'mask'], outputs=['output'])
 
     def check_shape(self, x_shape, mask_shape):
-        get_broadcast_shape(x_shape, mask_shape, self.name)
+        get_broadcast_shape(x_shape, mask_shape, self.name, arg_name1="x", arg_name2="mask")
 
     def check_dtype(self, x_dtype, mask_dtype):
         validator.check_tensor_dtype_valid('mask', mask_dtype, [mstype.bool_], self.name)
@@ -6402,7 +6403,7 @@ class TensorScatterMax(PrimitiveWithInfer):
             raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
                              f"the dimension of 'input_x', but got the "
                              f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
-                             f"{len(indices_shape)}.")
+                             f"{len(input_x_shape)}.")
 
         updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
         if updates_shape_check != updates_shape:
@@ -6486,7 +6487,7 @@ class TensorScatterMin(PrimitiveWithInfer):
             raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
                              f"the dimension of 'input_x', but got the "
                              f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
-                             f"{len(indices_shape)}.")
+                             f"{len(input_x_shape)}.")
 
         updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
         if updates_shape_check != updates_shape:
@@ -6571,7 +6572,7 @@ class TensorScatterSub(PrimitiveWithInfer):
             raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
                              f"the dimension of 'input_x', but got the "
                              f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
-                             f"{len(indices_shape)}.")
+                             f"{len(input_x_shape)}.")
 
         updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
         if updates_shape_check != updates_shape:
@@ -1333,7 +1333,7 @@ class BatchNorm(PrimitiveWithInfer):
         self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, "
-                             f"but got the 'format' is {self.format} and "
+                             f"but got the 'data_format' is {self.format} and "
                              f"the platform is {context.get_context('device_target')}.")
         self.add_prim_attr('data_format', self.format)
         self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
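The wording change ('format' to 'data_format') matches the public argument name users actually pass. The underlying rule is unchanged: NHWC is accepted only when the device target is GPU. A sketch of that validation as a free function (hypothetical; in the real code the check lives in each primitive's __init__):

```python
def check_data_format(data_format, device_target, prim_name):
    if data_format not in ('NCHW', 'NHWC'):
        raise ValueError(f"For '{prim_name}', 'data_format' must be 'NCHW' or 'NHWC', but got {data_format}.")
    if device_target != "GPU" and data_format == "NHWC":
        raise ValueError(f"For '{prim_name}', the 'NHWC' format is only supported in GPU target, "
                         f"but got the 'data_format' is {data_format} and "
                         f"the platform is {device_target}.")
    return data_format


check_data_format('NHWC', 'GPU', 'BatchNorm')      # ok
# check_data_format('NHWC', 'Ascend', 'BatchNorm')  # ValueError: NHWC only on GPU
```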
@@ -1481,13 +1481,13 @@ class Conv2D(Primitive):
             pad = (pad,) * 4
         else:
             validator.check_equal_int(len(pad), 4, 'pad size', self.name)
-        self.add_prim_attr("pad", pad)
-        self.padding = pad
         self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
 
         if pad_mode != 'pad' and pad != (0, 0, 0, 0):
             raise ValueError(f"For '{self.name}', the 'pad' must be zero when 'pad_mode' is not 'pad', "
-                             f"but got 'pad': {pad} and 'pad_mode': {pad_mode}.")
+                             f"but got 'pad': {self.pad} and 'pad_mode': {pad_mode}.")
+        self.add_prim_attr("pad", pad)
+        self.padding = pad
         if self.pad_mode == 'pad':
             for item in pad:
                 validator.check_non_negative_int(item, 'pad item', self.name)
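Besides the message text, this hunk (and the DepthwiseConv2dNative, Conv2DBackpropInput, and Conv3D hunks below) moves self.add_prim_attr("pad", pad) and self.padding = pad after the consistency check, so the attribute is only registered once the pad/pad_mode combination is known to be valid. The invariant itself is simple; a sketch with check_pad_with_mode as a hypothetical helper:

```python
def check_pad_with_mode(pad, pad_mode, prim_name):
    """'pad' may be non-zero only when pad_mode == 'pad'."""
    if isinstance(pad, int):
        pad = (pad,) * 4
    if len(pad) != 4:
        raise ValueError(f"For '{prim_name}', 'pad' should be an int or a tuple of four ints, but got {pad}.")
    if pad_mode != 'pad' and pad != (0, 0, 0, 0):
        raise ValueError(f"For '{prim_name}', the 'pad' must be zero when 'pad_mode' is not 'pad', "
                         f"but got 'pad': {pad} and 'pad_mode': {pad_mode}.")
    if pad_mode == 'pad' and any(p < 0 for p in pad):
        raise ValueError(f"For '{prim_name}', every item of 'pad' must be non-negative, but got {pad}.")
    return pad


print(check_pad_with_mode(1, 'pad', 'Conv2D'))     # (1, 1, 1, 1)
# check_pad_with_mode(1, 'same', 'Conv2D')          # ValueError: pad must be zero unless pad_mode == 'pad'
```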
@@ -1496,7 +1496,7 @@ class Conv2D(Primitive):
         self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, "
-                             f"but got the 'format' is {self.format} "
+                             f"but got the 'data_format' is {self.format} "
                              f"and platform is {context.get_context('device_target')}.")
         self.add_prim_attr('data_format', self.format)
         self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
@@ -1612,12 +1612,12 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
             pad = (pad,) * 4
         else:
             validator.check_equal_int(len(pad), 4, 'pad size', self.name)
-        self.add_prim_attr("pad", pad)
-        self.padding = pad
         self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
         if pad_mode != 'pad' and pad != (0, 0, 0, 0):
-            raise ValueError(f"For '{self.name}', the 'pad' must be zero when 'pad_mode' is not 'pad', "
-                             f"but got 'pad' is {pad} and 'pad_mode' is {pad_mode}.")
+            raise ValueError(f"For '{self.name}', the 'pad' must be zero or (0, 0, 0, 0) when 'pad_mode' "
+                             f"is not \"pad\", but got 'pad' is {self.pad} and 'pad_mode' is {pad_mode}.")
+        self.add_prim_attr("pad", pad)
+        self.padding = pad
         if self.pad_mode == 'pad':
             for item in pad:
                 validator.check_non_negative_int(item, 'pad item', self.name)
@@ -1705,7 +1705,7 @@ class _Pool(PrimitiveWithInfer):
         self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, "
-                             f"but got the 'format' is {self.format} and "
+                             f"but got the 'data_format' is {self.format} and "
                              f"the platform is {context.get_context('device_target')}.")
         if not self.is_maxpoolwithargmax:
             self.add_prim_attr('data_format', self.format)
@@ -2009,8 +2009,8 @@ class MaxPool3D(PrimitiveWithInfer):
             raise ValueError(f"For '{self.name}', attr 'pad_list' should be an positive int number or a tuple of "
                              f"three or six positive int numbers, but got {len(self.pad_list)} numbers.")
         if self.pad_mode != 'CALCULATED' and self.pad_list != (0, 0, 0, 0, 0, 0):
-            raise ValueError(f"For '{self.name}', the 'pad_list' must be zero when 'pad_mode' is not \"pad\", "
-                             f"but got 'pad_list' is {self.pad_list} and 'pad_mode' is {pad_mode}.")
+            raise ValueError(f"For '{self.name}', the 'pad_list' must be zero or (0, 0, 0, 0, 0, 0) when 'pad_mode' "
+                             f"is not \"pad\", but got 'pad_list' is {pad_list} and 'pad_mode' is {pad_mode}.")
         if self.pad_mode == 'CALCULATED':
             for item in self.pad_list:
                 validator.check_non_negative_int(item, 'pad_list item', self.name)
@@ -2176,7 +2176,7 @@ class Conv2DBackpropInput(Primitive):
         self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, "
-                             f"but got the 'format' is {self.format} and "
+                             f"but got the 'data_format' is {self.format} and "
                              f"the platform is {context.get_context('device_target')}.")
         self.add_prim_attr('data_format', self.format)
         self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
@@ -2190,12 +2190,12 @@ class Conv2DBackpropInput(Primitive):
             pad = (pad,) * 4
         else:
             validator.check_equal_int(len(pad), 4, 'pad size', self.name)
-        self.add_prim_attr("pad", pad)
-        self.padding = pad
         self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
         if pad_mode != 'pad' and pad != (0, 0, 0, 0):
-            raise ValueError(f"For '{self.name}', the 'pad' must be zero when 'pad_mode' is not \"pad\", "
-                             f"but got 'pad' is {pad} and 'pad_mode' is {pad_mode}.")
+            raise ValueError(f"For '{self.name}', the 'pad' must be zero or (0, 0, 0, 0) when 'pad_mode' "
+                             f"is not \"pad\", but got 'pad' is {self.pad} and 'pad_mode' is {pad_mode}.")
+        self.add_prim_attr("pad", pad)
+        self.padding = pad
         if self.pad_mode == 'pad':
             for item in pad:
                 validator.check_non_negative_int(item, 'pad item', self.name)
@@ -2317,8 +2317,8 @@ class BiasAdd(Primitive):
         self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
         self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
-            raise ValueError(f"For '{self.name}', the \"NHWC\" format is only supported in GPU target, "
-                             f"but got the 'format' is {self.format} and "
+            raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, "
+                             f"but got the 'data_format' is {self.format} and "
                              f"the platform is {context.get_context('device_target')}.")
         self.add_prim_attr('data_format', self.format)
@@ -7708,7 +7708,7 @@ class DynamicRNN(PrimitiveWithInfer):
         validator.check_int(len(h_shape), 3, Rel.EQ, "h_shape", self.name)
         validator.check_int(len(c_shape), 3, Rel.EQ, "c_shape", self.name)
         if seq_shape is not None:
-            raise ValueError(f"For '{self.name}', the dimension of 'seq_length' should be None, but got {seq_shape}.")
+            raise ValueError(f"For '{self.name}', the 'seq_length' should be None.")
 
         num_step, batch_size, input_size = x_shape
         hidden_size = w_shape[-1] // 4
@@ -8144,7 +8144,7 @@ class AvgPool3D(Primitive):
             pad = (pad,) * 6
         if len(pad) != 6:
             raise ValueError(f"For '{self.name}', attr 'pad' should be an positive int number or a tuple of "
-                             f"six positive int numbers, but got `{pad}`.")
+                             f"six positive int numbers, but got {self.pad}.")
         self.pad_list = pad
         self.add_prim_attr('pad_list', self.pad_list)
         validator.check_value_type('pad_mode', pad_mode, [str], self.name)
@@ -8152,8 +8152,8 @@ class AvgPool3D(Primitive):
         self.add_prim_attr('pad_mode', self.pad_mode)
 
         if self.pad_mode != 'PAD' and pad != (0, 0, 0, 0, 0, 0):
-            raise ValueError(f"For '{self.name}', the 'pad' must be (0, 0, 0, 0, 0, 0) when 'pad_mode' is not \"pad\", "
-                             f"but got 'pad' is {pad} and 'pad_mode' is {self.pad_mode}.")
+            raise ValueError(f"For '{self.name}', the 'pad' must be zero or (0, 0, 0, 0, 0, 0) when 'pad_mode' "
+                             f"is not \"PAD\", but got 'pad' is {self.pad} and 'pad_mode' is {pad_mode}.")
         if self.pad_mode == 'PAD':
             for item in pad:
                 validator.check_non_negative_int(item, 'pad or item of pad', self.name)
@@ -8289,16 +8289,16 @@ class Conv3D(PrimitiveWithInfer):
             pad = (pad,) * 6
         if len(pad) != 6:
             raise ValueError(f"For '{self.name}', attr 'pad' should be an positive int number or a tuple of "
-                             f"six positive int numbers, but got `{pad}`.")
-        self.add_prim_attr("pad", pad)
-        self.padding = pad
+                             f"six positive int numbers, but got {self.pad}.")
         validator.check_value_type('pad_mode', pad_mode, [str], self.name)
         self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
         self.add_prim_attr('pad_mode', self.pad_mode)
 
         if self.pad_mode != 'pad' and pad != (0, 0, 0, 0, 0, 0):
-            raise ValueError(f"For '{self.name}', the 'pad' must be (0, 0, 0, 0, 0, 0) when 'pad_mode' is not \"pad\", "
-                             f"but got 'pad' is {pad} and 'pad_mode' is {self.pad_mode}.")
+            raise ValueError(f"For '{self.name}', the 'pad' must be zero or (0, 0, 0, 0, 0, 0) when 'pad_mode' "
+                             f"is not \"pad\", but got 'pad' is {self.pad} and 'pad_mode' is {pad_mode}.")
+        self.add_prim_attr("pad", pad)
+        self.padding = pad
         if self.pad_mode == 'pad':
             for item in pad:
                 validator.check_non_negative_int(item, 'pad item', self.name)
@@ -8316,7 +8316,7 @@ class Conv3D(PrimitiveWithInfer):
         validator.check_equal_int(len(w_shape), 5, "weight rank", self.name)
         validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
         if b_shape is not None:
-            raise ValueError(f"For '{self.name}', the 'bias' currently only support None, but got {b_shape}.")
+            raise ValueError(f"For '{self.name}', the 'bias' currently only support None.")
         validator.check(f"x_shape[1] // group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
         validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
         validator.check('kernel_size', self.kernel_size, 'w_shape[1:4]', tuple(w_shape[2:]), Rel.EQ, self.name)
@@ -8778,15 +8778,15 @@ class Conv3DTranspose(PrimitiveWithInfer):
             pad = (pad,) * 6
         if len(pad) != 6:
             raise ValueError(f"For '{self.name}', attr 'pad' should be an positive int number or a tuple of "
-                             f"six positive int numbers, but got `{pad}`.")
+                             f"six positive int numbers, but got {self.pad}.")
         self.pad_list = pad
         validator.check_value_type('pad_mode', pad_mode, [str], self.name)
         self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
         self.add_prim_attr('pad_mode', self.pad_mode)
 
         if self.pad_mode != 'pad' and pad != (0, 0, 0, 0, 0, 0):
-            raise ValueError(f"For '{self.name}', the 'pad' must be (0, 0, 0, 0, 0, 0) when 'pad_mode' is not \"pad\", "
-                             f"but got 'pad' is {pad} and 'pad_mode' is {self.pad_mode}.")
+            raise ValueError(f"For '{self.name}', the 'pad' must be zero or (0, 0, 0, 0, 0, 0) when 'pad_mode' "
+                             f"is not \"pad\", but got 'pad' is {self.pad} and 'pad_mode' is {pad_mode}.")
 
         if self.pad_mode == 'pad':
             for item in self.pad_list:
@@ -8800,11 +8800,11 @@ class Conv3DTranspose(PrimitiveWithInfer):
 
         self.output_padding = _check_3d_int_or_tuple('output_padding', output_padding, self.name,
                                                      allow_five=False, ret_five=True, greater_zero=False)
-        output_padding = (self.output_padding[2], self.output_padding[3], self.output_padding[4])
-        if self.pad_mode != 'pad' and output_padding != (0, 0, 0):
-            raise ValueError(f"For '{self.name}', the 'output_padding' must be (0, 0, 0) "
-                             f"when 'pad_mode' is not \"pad\", "
-                             f"but got 'output_padding' is {output_padding} and 'pad_mode' is {self.pad_mode}.")
+        output_padding_ = (self.output_padding[2], self.output_padding[3], self.output_padding[4])
+        if self.pad_mode != 'pad' and output_padding_ != (0, 0, 0):
+            raise ValueError(f"For '{self.name}', the 'output_padding' must be zero or (0, 0, 0) "
+                             f"when 'pad_mode' is not \"pad\", but got 'output_padding' is "
+                             f"{output_padding} and 'pad_mode' is {pad_mode}.")
         validator.check_int_range(self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2], 1, 343, Rel.INC_BOTH,
                                   'The product of height, width and depth of kernel_size belonging [1, 343]', self.name)
         validator.check_int_range(self.stride[0] * self.stride[1] * self.stride[2], 1, 343, Rel.INC_BOTH,
@@ -195,8 +195,10 @@ class Gamma(PrimitiveWithInfer):
             Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
         Validator.check_tensor_dtype_valid("alpha", alpha["dtype"], [mstype.float32], self.name)
         Validator.check_tensor_dtype_valid("beta", beta["dtype"], [mstype.float32], self.name)
-        broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name)
-        broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
+        broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name,
+                                              arg_name1="alpha", arg_name2="beta")
+        broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name,
+                                              arg_name1="broadcast_alpha_beta", arg_name2="shape")
         out = {
             'shape': broadcast_shape,
             'dtype': mstype.float32,
@@ -258,7 +260,7 @@ class Poisson(PrimitiveWithInfer):
         for i, shape_i in enumerate(shape_v):
             Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
         Validator.check_tensor_dtype_valid("mean", mean["dtype"], [mstype.float32], self.name)
-        broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name)
+        broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name, arg_name1="mean", arg_name2="shape")
         out = {
             'shape': broadcast_shape,
             'dtype': mstype.int32,
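Gamma chains two broadcasts: alpha against beta, then that result against the requested sample shape; the new arg_name1/arg_name2 labels keep failures from the two steps distinguishable. An illustrative chain, reusing the get_broadcast_shape sketch from the broadcast hunk above (shape values are made up):

```python
# Reuses the simplified get_broadcast_shape sketch defined earlier.
alpha_shape, beta_shape, sample_shape = [3, 1], [1, 4], [2, 3, 4]
bs = get_broadcast_shape(alpha_shape, beta_shape, 'Gamma',
                         arg_name1="alpha", arg_name2="beta")            # [3, 4]
out = get_broadcast_shape(bs, sample_shape, 'Gamma',
                          arg_name1="broadcast_alpha_beta", arg_name2="shape")
print(out)  # [2, 3, 4]
```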
@@ -1204,8 +1204,8 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
     net = NetWork()
     with pytest.raises(IndexError) as ex:
         net(input_tensor)
-    assert "'begin' should be in [-6, 6) when 'shrink_axis_mask' is greater than 0, " \
-           "but got 'shrink_axis_mask': 7, 'strides': 1, 'begin': -7." in str(ex.value)
+    assert "'begin[0]' should be in [-6, 6) when 'shrink_axis_mask' is greater than 0, " \
+           "but got 'shrink_axis_mask': 7, 'strides[0]': 1, 'begin[0]': -7." in str(ex.value)
 
 
 @pytest.mark.level1
@@ -1227,8 +1227,8 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
     net = NetWork()
     with pytest.raises(IndexError) as ex:
         net(input_tensor)
-    assert "'begin' should be in [-6, 6) when 'shrink_axis_mask' is greater than 0, " \
-           "but got 'shrink_axis_mask': 7, 'strides': 1, 'begin': 6." in str(ex.value)
+    assert "'begin[0]' should be in [-6, 6) when 'shrink_axis_mask' is greater than 0, " \
+           "but got 'shrink_axis_mask': 7, 'strides[0]': 1, 'begin[0]': 6." in str(ex.value)
 
 
 @pytest.mark.level0
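These test updates simply track the new per-index wording. When asserting on messages like this, pytest.raises(match=...) is an alternative to substring checks; note that match is a regex, so brackets and parentheses must be escaped (illustrative, not from the commit):

```python
import re

import pytest


def test_message_wording():
    with pytest.raises(IndexError, match=re.escape("'begin[0]' should be in [-6, 6)")):
        raise IndexError("For 'StridedSlice', the 'strides[0]' cannot be negative number and "
                         "'begin[0]' should be in [-6, 6) when 'shrink_axis_mask' is greater than 0, "
                         "but got 'shrink_axis_mask': 7, 'strides[0]': 1, 'begin[0]': -7.")
```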