!23603 update error msg after code audit

Merge pull request !23603 from dinglinhe/dlh_code_ms_I43QY0_afteraudit
i-robot 2021-09-22 17:24:48 +00:00 committed by Gitee
commit 3d0d0da8f9
10 changed files with 96 additions and 80 deletions


@@ -439,8 +439,8 @@ class Validator:
             return arg_value
         arg_name = arg_name if arg_name else "Parameter"
         msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
-        raise ValueError(f'{msg_prefix} `{arg_name}` should be str and must be in `{valid_values}`,'
-                         f' but got `{arg_value}`.')
+        raise ValueError(f"{msg_prefix} '{arg_name}' should be str and must be in '{valid_values}',"
+                         f" but got '{arg_value}'.")
 
     @staticmethod
     def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
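For context, a minimal standalone sketch of the check this hunk rewrites (the enclosing method is not shown in the diff; the name check_string and its signature are assumed here), showing the message the new quoting produces:

    # Hypothetical free-standing version of the validator check; names assumed.
    def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
        """Return arg_value when it is one of the allowed strings, else raise."""
        if isinstance(arg_value, str) and arg_value in valid_values:
            return arg_value
        arg_name = arg_name if arg_name else "Parameter"
        msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
        raise ValueError(f"{msg_prefix} '{arg_name}' should be str and must be in '{valid_values}',"
                         f" but got '{arg_value}'.")

    check_string("same", ["same", "valid", "pad"], "pad_mode", "Conv2D")   # returns "same"
    # An invalid value now reads:
    # ValueError: For 'Conv2D' the 'pad_mode' should be str and must be in
    # '['same', 'valid', 'pad']', but got 'full'.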


@@ -155,7 +155,7 @@ class SequentialCell(Cell):
                 cell.update_parameters_name(name + ".")
                 self._is_dynamic_name.append(False)
             else:
-                raise TypeError(f"For '{self.__class__.__name__}', Cells must be list or orderedDict, "
+                raise TypeError(f"For '{self.__class__.__name__}', the 'args[0]' must be list or orderedDict, "
                                 f"but got {type(cells).__name__}")
         else:
             for index, cell in enumerate(args):
@@ -369,7 +369,7 @@ class CellList(_CellListBase, Cell):
         cls_name = self.__class__.__name__
         if not isinstance(cells, list):
             raise TypeError(f"For '{cls_name}', the new cells wanted to append "
-                            f"should be list of subcells.")
+                            f"should be instance of list.")
         prefix, _ = _get_prefix_and_index(self._cells)
         for cell in cells:
             if _valid_cell(cell, cls_name):
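A quick illustration of the two argument forms the first hunk names (a sketch; assumes a working mindspore install):

    from collections import OrderedDict
    import mindspore.nn as nn

    # The single positional argument may be a list ...
    seq = nn.SequentialCell([nn.Dense(4, 8), nn.ReLU()])
    # ... or an OrderedDict, which also names the sub-cells:
    named = nn.SequentialCell(OrderedDict([('fc', nn.Dense(4, 8)), ('act', nn.ReLU())]))
    # Any other single argument, e.g. nn.SequentialCell((nn.ReLU(),)), now raises:
    # TypeError: For 'SequentialCell', the 'args[0]' must be list or orderedDict, but got tuple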


@@ -291,7 +291,7 @@ class Conv2d(_Conv):
 @constexpr
 def _check_input_3d(input_shape, op_name):
     if len(input_shape) != 3:
-        raise ValueError(f"For '{op_name}', the shape of input should be 3d, but got shape {input_shape}")
+        raise ValueError(f"For '{op_name}', the dimension of input should be 3d, but got {len(input_shape)}.")
 
 
 class Conv1d(_Conv):
@@ -480,7 +480,7 @@ class Conv1d(_Conv):
 @constexpr
 def _check_input_5dims(input_shape, op_name):
     if len(input_shape) != 5:
-        raise ValueError(f"For '{op_name}', the input shape should be 5 dimensions, but got shape {input_shape}.")
+        raise ValueError(f"For '{op_name}', the dimension of input should be 5d, but got {len(input_shape)}.")
 
 
 class Conv3d(_Conv):


@@ -39,7 +39,7 @@ __all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']
 @constexpr
 def _check_input_2d(input_shape, param_name, func_name):
     if len(input_shape) != 2:
-        raise ValueError(f"For '{func_name}', the '{param_name}' should be 2d, but got shape {input_shape}")
+        raise ValueError(f"For '{func_name}', the dimension of '{param_name}' should be 2d, but got {len(input_shape)}")
     return True
@@ -301,7 +301,7 @@ class EmbeddingLookup(Cell):
         if is_auto_parallel:
             support_mode = ["field_slice", "table_row_slice", "table_column_slice", "batch_slice"]
             raise ValueError("For '{}', the 'slice_mode' must be in {}, "
-                             "but got {}.".format(self.cls_name, support_mode, slice_mode))
+                             "but got \"{}\".".format(self.cls_name, support_mode, slice_mode))
         if self.cache_enable and not enable_ps:
             if parallel_mode != ParallelMode.STAND_ALONE:
                 raise ValueError(f"For '{self.cls_name}', parallel mode haven't supported cache enable yet.")
@@ -353,8 +353,8 @@ class EmbeddingLookup(Cell):
         full_batch = _get_full_batch()
         if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
             raise ValueError(f"For '{self.cls_name}', the cache of parameter server parallel only be used "
-                             f"in \"full_batch\" and \"table_row_slice\" parallel strategy, but got "
-                             f"full_batch: {full_batch} and 'slice_mode': {slice_mode}.")
+                             f"in \"full_batch\" and \"table_row_slice\" 'slice_mode', but got "
+                             f"full_batch: {full_batch} and 'slice_mode': \"{slice_mode}\".")
         self.vocab_cache_size = self.vocab_cache_size * rank_size
         _set_rank_id(rank_id)
         self.cache_enable = True


@@ -117,7 +117,8 @@ def _get_dtype_max(dtype):
 @constexpr
 def _check_input_4d(input_shape, param_name, func_name):
     if len(input_shape) != 4:
-        raise ValueError(f"For '{func_name}', the '{param_name}' should be 4d, but got shape {input_shape}.")
+        raise ValueError(f"For '{func_name}', the dimension of '{param_name}' should be 4d, "
+                         f"but got {len(input_shape)}.")
     return True


@@ -1203,9 +1203,9 @@ class BCEWithLogitsLoss(LossBase):
         super(BCEWithLogitsLoss, self).__init__()
         self.bce_with_logits_loss = P.BCEWithLogitsLoss(reduction=reduction)
         if isinstance(weight, Parameter):
-            raise TypeError(f"For '{self.cls_name}', the 'weight' can not be a parameter.")
+            raise TypeError(f"For '{self.cls_name}', the 'weight' can not be a Parameter.")
         if isinstance(pos_weight, Parameter):
-            raise TypeError(f"For '{self.cls_name}', the 'pos_weight' can not be a parameter.")
+            raise TypeError(f"For '{self.cls_name}', the 'pos_weight' can not be a Parameter.")
         self.weight = weight
         self.pos_weight = pos_weight
         self.ones = P.OnesLike()
@@ -1228,7 +1228,7 @@ class BCEWithLogitsLoss(LossBase):
 @constexpr
 def _check_ndim(logits_nidm, labels_ndim, prime_name=None):
-    '''Internal function, used to check whether the dimension of logits and labels meets the requirements'''
+    '''Internal function, used to check whether the dimension of logits and labels meets the requirements.'''
     msg_prefix = f'For \'{prime_name}\', the' if prime_name else "The"
     if logits_nidm < 2 or logits_nidm > 4:
         raise ValueError(f"{msg_prefix} dimensions of 'logits' should be in [2, 4], but got"
@@ -1248,8 +1248,9 @@ def _check_channel_and_shape(logits, labels, prime_name=None):
     if logits == 1:
         raise ValueError(f"{msg_prefix} single channel prediction is not supported, but got {logits}.")
     if labels not in (1, logits):
-        raise ValueError(f"{msg_prefix} 'labels' must have a channel or the same shape as 'logits'."
-                         f"If it has a channel, it should be the range [0, C-1], where C is the number of classes "
+        raise ValueError(f"{msg_prefix} channel of 'labels' must be one or the 'labels' must be the same as that of "
+                         f"the 'logits'. If there is only one channel, its value should be in the range [0, C-1], "
+                         f"where C is the number of classes "
                          f"inferred from 'logits': C={logits}, but got 'labels': {labels}.")


@@ -252,7 +252,7 @@ def is_same_type(inst, type_):
 def check_valid_dim(dim, name):
     """Checks whether the dim is valid."""
     if dim not in (1, 2):
-        raise ValueError(f"For '{name}', inputs dim must be 1d or 2d, but got {dim}.")
+        raise ValueError(f"For '{name}', the dimension of inputs must be 1d or 2d, but got {dim}.")
 
 
 @constexpr


@@ -526,7 +526,7 @@ class Reshape(PrimitiveWithInfer):
         if dim_prod <= 0:
             raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, "
                              f"the value of 'input_shape' is {shape_v}. "
-                             f"The product of shape of 'input_shape' should > 0, but got {dim_prod}.")
+                             f"The product of 'input_shape' should > 0, but got {dim_prod}.")
         if neg_index != -1:
             shape_v[neg_index] = int(arr_prod / dim_prod)
             dim_prod *= shape_v[neg_index]
@@ -534,8 +534,8 @@ class Reshape(PrimitiveWithInfer):
             raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, "
                              f"the value of 'input_shape' value is {shape_v}. "
                              f"The product of the shape of 'input_x' should be equal to product of 'input_shape', "
-                             f"but product of the shape of 'input_x' is {arr_prod} "
-                             f", product of 'input_shape' is {dim_prod}.")
+                             f"but product of the shape of 'input_x' is {arr_prod}, "
+                             f"product of 'input_shape' is {dim_prod}.")
         value = None
         if x['value'] is not None:
             value = Tensor(x['value'].asnumpy().reshape(shape_v))
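The invariant both Reshape hunks restate, as a short sketch (shapes hypothetical):

    import numpy as np

    x_shp = (2, 3, 4)                     # arr_prod = 24
    shape_v = [4, -1]                     # dim_prod over the known entries = 4
    inferred = int(np.prod(x_shp)) // 4   # the -1 entry becomes 6
    assert int(np.prod(x_shp)) == 4 * inferred
    # shape_v = [5, -1] leaves 24 not divisible by 5, so the two products
    # differ and the second error message above fires.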
@@ -1081,9 +1081,9 @@ class Split(PrimitiveWithCheck):
             # only validate when shape fully known
             output_valid_check = x_shape[self.axis] % self.output_num
             if output_valid_check != 0:
-                raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shape}, 'axis' is {self.axis}, "
-                                 f"the shape of 'input_x' in 'axis' {self.axis} is {x_shape[self.axis]}, "
-                                 f"which must be divide exactly by 'output_num': {self.output_num}.")
+                raise ValueError(f"For '{self.name}', the specified axis of 'input_x' should be divided exactly by "
+                                 f"'output_num', but got the shape of 'input_x' in 'axis' {self.axis} is "
+                                 f"{x_shape[self.axis]}, 'output_num': {self.output_num}.")
         size_splits = [x_shape[self.axis] // self.output_num] * self.output_num
         self.add_prim_attr('size_splits', size_splits)
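The divisibility rule behind the Split message, sketched with hypothetical values:

    # Sketch: Split needs x_shape[axis] % output_num == 0.
    x_shape, axis, output_num = [4, 6], 1, 3
    assert x_shape[axis] % output_num == 0                      # 6 splits into 3 x 2
    size_splits = [x_shape[axis] // output_num] * output_num    # [2, 2, 2]
    # output_num = 4 leaves 6 % 4 == 2 and triggers the reworded ValueError.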
@@ -1603,7 +1603,7 @@ class InvertPermutation(PrimitiveWithInfer):
         for shp in x_shp:
             if shp:
                 x_rank = len(np.array(x_value, np.int64).shape)
-                raise ValueError(f"For \'{self.name}\', the length of 'input_x' must be 1, but got {x_rank}.")
+                raise ValueError(f"For \'{self.name}\', the dimension of 'input_x' must be 1, but got {x_rank}.")
         for i, value in enumerate(x_value):
             validator.check_value_type("input[%d]" % i, value, [int], self.name)
         z = [x_value[i] for i in range(len(x_value))]
@@ -1978,7 +1978,8 @@ class Tile(PrimitiveWithInfer):
             multiples_w = multiples_v
         elif len_sub < 0:
             raise ValueError(f"For '{self.name}', the length of 'multiples' can not be smaller than "
-                             f"the length of dimension in 'input_x'.")
+                             f"the dimension of 'input_x', but got length of 'multiples': {len(multiples_v)} "
+                             f"and dimension of 'input_x': {len(x_shp)}.")
         for i, a in enumerate(multiples_w):
             x_shp[i] *= a
         value = None
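How the Tile length check plays out, with hypothetical shapes:

    # Sketch: len(multiples) must be >= the rank of 'input_x'; extra leading
    # entries prepend new axes, a shorter tuple is rejected.
    x_shp = [2, 3]
    multiples = (4, 2, 3)    # ok: x treated as [1, 2, 3] -> output shape [4, 4, 9]
    # multiples = (2,) gives len_sub = 1 - 2 < 0 and raises the new ValueError,
    # which now also reports both lengths.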
@@ -2736,7 +2737,7 @@ class Slice(PrimitiveWithInfer):
             validator.check_non_negative_int(begin_v[i], f'input begin[{i}]')
             if x_shape[i] < begin_v[i] + size_v[i]:
                 y = begin_v[i] + size_v[i]
-                raise ValueError(f"For '{self.name}', the sliced shape can not greater than origin shape, but got "
+                raise ValueError(f"For '{self.name}', the sliced shape can not be greater than origin shape, but got "
                                  f"sliced shape is {y}, and origin shape is {x_shape}.")
         return {'shape': size_v,
                 'dtype': x['dtype'],
@@ -2924,9 +2925,9 @@ class Select(Primitive):
 def _compute_slicing_length(begin, end, stride, x_shape, i):
     """Computes the length of the slicing."""
     if i >= len(x_shape):
-        raise ValueError(f"For 'StridedSlice', the index length must be less than or equal to "
-                         f"the dimension of 'input_x' when there is no new axis, but got "
-                         f"the dimension of 'input_x': {len(x_shape)} and the index length: {i}.")
+        raise ValueError(f"For 'StridedSlice', the index must be less than or equal to "
+                         f"the dimension of 'input_x', but got the dimension of 'input_x': {len(x_shape)} "
+                         f"and the index: {i}.")
     x_dim = x_shape[i]
     if stride > 0:
         # When slicing forward, convert begin and end to positive numbers.
@@ -3233,8 +3234,10 @@ class StridedSlice(PrimitiveWithInfer):
             if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
                 if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
                     raise IndexError(f"For '{self.name}', the 'strides' cannot be negative number and "
-                                     f"'begin' should be in [-{x_shape[i]}, {x_shape[i]}) when shrink axis, "
-                                     f"but got 'strides': {stride}, 'begin': {begin}.")
+                                     f"'begin' should be in [-{x_shape[i]}, {x_shape[i]}) "
+                                     f"when 'shrink_axis_mask' is greater than 0, "
+                                     f"but got 'shrink_axis_mask': {self.shrink_axis_mask}, 'strides': {stride}, "
+                                     f"'begin': {begin}.")
                 j += 1
                 i += 1
                 continue
@@ -3267,8 +3270,10 @@ class StridedSlice(PrimitiveWithInfer):
             if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
                 if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
                     raise IndexError(f"For '{self.name}', the 'strides' cannot be negative number and "
-                                     f"'begin' should be in [-{x_shape[i]}, {x_shape[i]}) when shrink axis, "
-                                     f"but got 'strides': {stride}, 'begin': {begin}.")
+                                     f"'begin' should be in [-{x_shape[i]}, {x_shape[i]}) "
+                                     f"when 'shrink_axis_mask' is greater than 0, "
+                                     f"but got 'shrink_axis_mask': {self.shrink_axis_mask}, 'strides': {stride}, "
+                                     f"'begin': {begin}.")
                 j += 1
                 i += 1
                 continue
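What the shrink-axis condition means in practice (a sketch; a 1-D input of length 6, matching the tests updated at the end of this diff):

    # Sketch: when the shrink_axis_mask bit for an axis is set, that axis is
    # dropped from the output, so 'begin' must address a real element and the
    # stride must be positive.
    x_len = 6
    for begin, stride in [(5, 1), (-7, 1), (0, -1)]:
        ok = (-x_len <= begin < x_len) and stride > 0
        # (5, 1) -> ok; (-7, 1) and (0, -1) -> the IndexError above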
@@ -3759,11 +3764,11 @@ class TensorScatterUpdate(PrimitiveWithInfer):
         updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
         if updates_shape_check != updates_shape:
-            raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to "
-                             f"the shape of updates_shape_check, but got the shape of 'update': {updates_shape},"
-                             f"and the shape of updates_shape_check: {updates_shape_check}. Please check the shape of "
-                             f"'indices' and 'input_x', they should be meeting followings formula:\n"
-                             f" updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:].")
+            raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
+                             f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
+                             f"but got the shape of 'update': {updates_shape}, "
+                             f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
+                             f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")
         return input_x_shape
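The shape formula the TensorScatter* messages now embed, evaluated once with hypothetical shapes (the same rule recurs in the Add/Max/Min/Sub hunks below):

    # Sketch: 'updates' must match indices_shape[:-1] + input_x_shape[indices_shape[-1]:].
    input_x_shape = [4, 5, 6]
    indices_shape = [2, 1]          # the last dim indexes 1 leading axis of input_x
    updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
    assert updates_shape_check == [2, 5, 6]   # required shape of 'updates'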
@@ -3844,11 +3849,11 @@ class TensorScatterAdd(PrimitiveWithInfer):
         updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
         if updates_shape_check != updates_shape:
-            raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to "
-                             f"the shape of updates_shape_check, but got the shape of 'update': {updates_shape},"
-                             f"and the shape of updates_shape_check: {updates_shape_check}. Please check the shape of "
-                             f"'indices' and 'input_x', they should be meeting followings formula:\n"
-                             f" updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:].")
+            raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
+                             f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
+                             f"but got the shape of 'update': {updates_shape}, "
+                             f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
+                             f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")
         return input_x_shape
@@ -5132,11 +5137,12 @@ class SpaceToBatchND(PrimitiveWithInfer):
             padded = out_shape[i + offset] + self.paddings[i][0] + \
                      self.paddings[i][1]
             if padded % self.block_shape[i] != 0:
-                msg_ndim = "2nd" if i + 2 == 2 else "3rd"
-                raise ValueError(f"For '{self.name}', the 2nd and 3rd dimension of the output tensor should be "
-                                 f"divisible by 'block_shape', but got the {msg_ndim} dimension of output: {padded} "
-                                 f"and the {i} dimension block_shape: {self.block_shape}. Please check the "
-                                 f"official homepage for more information about the output tensor.")
+                raise ValueError(f"For '{self.name}', the padded should be divisible by 'block_shape', "
+                                 f"where padded = input_x_shape[i + 2] + paddings[i][0] + paddings[i][1], "
+                                 f"but got input_x_shape[{i + 2}]: {out_shape[i + offset]}, "
+                                 f"paddings[{i}][0]: {self.paddings[i][0]} and paddings[{i}][1]: {self.paddings[i][1]}."
+                                 f" Please check the official api documents for "
+                                 f"more information about the output tensor.")
             out_shape[i + offset] = padded // self.block_shape[i]
             block_shape_prod = block_shape_prod * self.block_shape[i]
         out_shape[0] *= block_shape_prod
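The padded-divisibility rule, worked through once with hypothetical numbers:

    # Sketch: per spatial axis i, SpaceToBatchND needs
    # (input_x_shape[i + 2] + paddings[i][0] + paddings[i][1]) % block_shape[i] == 0.
    input_h, pad_top, pad_bottom, block_h = 5, 1, 2, 4
    padded = input_h + pad_top + pad_bottom      # 8
    assert padded % block_h == 0                 # output spatial size: 8 // 4 == 2
    # pad_bottom = 1 gives padded == 7 and raises the reworded ValueError.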
@@ -5239,8 +5245,10 @@ class BatchToSpaceND(PrimitiveWithInfer):
             out_shape[i + offset] = x_block_prod - crops_sum
         if out_shape[0] % block_shape_prod != 0:
-            raise ValueError(f"For '{self.name}', the 0th dimension of the output tensor should be "
-                             f"divisible by block_shape_prod, but got 0th dimension of the output tensor: "
+            raise ValueError(f"For '{self.name}', the 0th dimension of the 'input_x' should be "
+                             f"divisible by block_shape_prod, where block_shape_prod = "
+                             f"'block_shape[0]' * 'block_shape[1]', "
+                             f"but got 0th dimension of the 'input_x': "
                              f"{out_shape[0]} and the block_shape_prod: {block_shape_prod}.")
         out_shape[0] = out_shape[0] // block_shape_prod
         return out_shape
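And the inverse constraint for BatchToSpaceND, with hypothetical numbers:

    # Sketch: the batch axis must fold evenly back into the spatial axes.
    batch, block_shape = 8, [2, 2]
    block_shape_prod = block_shape[0] * block_shape[1]   # 4
    assert batch % block_shape_prod == 0                 # new batch: 8 // 4 == 2
    # batch = 6 fails 6 % 4 == 0 and raises the reworded ValueError.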
@@ -6088,8 +6096,8 @@ class SearchSorted(PrimitiveWithInfer):
             raise ValueError(f"For '{self.name}', the 'sequence' should be 1 dimensional or "
                              f"all dimensions except the last dimension of 'sequence' "
                              f"must be the same as all dimensions except the last dimension of 'values'. "
-                             f"but got dimension of 'sequence': {sequence_shape} "
-                             f"and dimension of 'values': {values_shape}.")
+                             f"but got shape of 'sequence': {sequence_shape} "
+                             f"and shape of 'values': {values_shape}.")
         return values_shape
 
     def infer_dtype(self, sequence_dtype, values_dtype):
@@ -6166,11 +6174,11 @@ class TensorScatterMax(PrimitiveWithInfer):
         updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
         if updates_shape_check != updates_shape:
-            raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to "
-                             f"the shape of updates_shape_check, but got the shape of 'update': {updates_shape},"
-                             f"and the shape of updates_shape_check: {updates_shape_check}. Please check the shape of "
-                             f"'indices' and 'input_x', they should be meeting followings formula:\n"
-                             f" updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:].")
+            raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
+                             f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
+                             f"but got the shape of 'update': {updates_shape}, "
+                             f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
+                             f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")
         return input_x_shape
@@ -6250,11 +6258,11 @@ class TensorScatterMin(PrimitiveWithInfer):
         updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
         if updates_shape_check != updates_shape:
-            raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to "
-                             f"the shape of updates_shape_check, but got the shape of 'update': {updates_shape},"
-                             f"and the shape of updates_shape_check: {updates_shape_check}. Please check the shape of "
-                             f"'indices' and 'input_x', they should be meeting followings formula:\n"
-                             f" updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:].")
+            raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
+                             f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
+                             f"but got the shape of 'update': {updates_shape}, "
+                             f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
+                             f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")
         return input_x_shape
@@ -6335,11 +6343,11 @@ class TensorScatterSub(PrimitiveWithInfer):
         updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
         if updates_shape_check != updates_shape:
-            raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to "
-                             f"the shape of updates_shape_check, but got the shape of 'update': {updates_shape},"
-                             f"and the shape of updates_shape_check: {updates_shape_check}. Please check the shape of "
-                             f"'indices' and 'input_x', they should be meeting followings formula:\n"
-                             f" updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:].")
+            raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
+                             f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
+                             f"but got the shape of 'update': {updates_shape}, "
+                             f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
+                             f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")
         return input_x_shape


@@ -1569,7 +1569,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
         _, _, stride_h, stride_w = self.stride
         _, _, dilation_h, dilation_w = self.dilation
         if kernel_size_n != 1:
-            raise ValueError(f"For '{self.name}', the batch of input weight should be 1, but got {kernel_size_n}")
+            raise ValueError(f"For '{self.name}', the batch of 'weight' should be 1, but got {kernel_size_n}")
         if self.pad_mode == "valid":
             h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
             w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
@@ -1676,7 +1676,9 @@ class _Pool(PrimitiveWithInfer):
         if shape_value <= 0:
             raise ValueError(f"For '{self.name}', the each element of the output shape must be larger than 0, "
                              f"but got output shape: {out_shape}. The input shape: {x_shape}, "
-                             f"kernel size: {self.kernel_size}, strides: {self.strides}.")
+                             f"kernel size: {self.kernel_size}, strides: {self.strides}."
+                             f"Please check the official api documents for "
+                             f"more information about the output.")
         return out_shape
 
     def infer_dtype(self, x_dtype):
@@ -3379,7 +3381,7 @@ class L2Normalize(PrimitiveWithInfer):
             self.add_prim_attr('axis', axis)
             self.init_attrs['axis'] = axis
             if len(axis) != 1:
-                raise TypeError(f"For '{self.name}', the dimension of 'axis' must be 1, but got {len(axis)}, "
+                raise TypeError(f"For '{self.name}', the length of 'axis' must be 1, but got {len(axis)}, "
                                 f"later will support multiple axis!")
             self.axis = axis
@@ -3868,7 +3870,7 @@ class PReLU(PrimitiveWithInfer):
         weight_dim = len(weight_shape)
         if weight_dim != 1:
-            raise ValueError(f"For '{self.name}', the dimension of 'x' should be 1, while got {weight_dim}.")
+            raise ValueError(f"For '{self.name}', the dimension of 'weight' should be 1, while got {weight_dim}.")
         if weight_shape[0] != 1 and weight_shape[0] != channel_num:
             raise ValueError(f"For '{self.name}', the first dimension of 'weight' should be (1,) or "
                              f"it should be equal to number of channels: {channel_num}, but got {weight_shape}")
@ -4193,8 +4195,8 @@ class Pad(PrimitiveWithInfer):
f"but got {type(paddings)}.") f"but got {type(paddings)}.")
for item in paddings: for item in paddings:
if len(item) != 2: if len(item) != 2:
raise ValueError(f"For '{self.name}', the shape of paddings must be (n, 2), " raise ValueError(f"For '{self.name}', the shape of 'paddings' must be (n, 2), "
f"but got {item}.") f"but got {paddings}.")
self.paddings = paddings self.paddings = paddings
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
@@ -4299,8 +4301,10 @@ class MirrorPad(PrimitiveWithInfer):
             adjust = 1
         for i in range(0, int(paddings_size / 2)):
             if (paddings_value[i, 0] >= x_shape[i] + adjust) or (paddings_value[i, 1] >= x_shape[i] + adjust):
-                raise ValueError(f"For '{self.name}', both paddings[D, 0] and paddings[D, 1] must be no greater than "
-                                 f"the dimension corresponding to 'x'.")
+                msg = "x_shape[D] + 1" if adjust == 1 else "x_shape[D]"
+                raise ValueError(f"For '{self.name}', both paddings[D, 0] and paddings[D, 1] must be less than {msg}, "
+                                 f"but got paddings[{i}, 0]: {paddings_value[i, 0]}, "
+                                 f"paddings[{i}, 1]: {paddings_value[i, 1]}, x_shape[{i}]: {x_shape[i]}.")
         y_shape = ()
         for i in range(0, int(paddings_size / 2)):
             y_shape += ((x_shape[i] + paddings_value[i, 0] + paddings_value[i, 1]),)
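Note: the new message interpolates the values as {paddings_value[i, 0]} and {paddings_value[i, 1]} above to match the guard on the same line; interpolating a bare `paddings[i, 0]` would not resolve in this scope. The bound the message now states, sketched for both modes (values hypothetical):

    # Sketch: REFLECT (adjust == 0) requires padding < x_shape[D]; SYMMETRIC
    # (adjust == 1) also allows padding == x_shape[D].
    x_dim, pad_amount = 3, 3
    assert not pad_amount < x_dim + 0      # REFLECT: rejected, needs >= 4 elements
    assert pad_amount < x_dim + 1          # SYMMETRIC: accepted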
@@ -8131,7 +8135,7 @@ class AvgPool3D(Primitive):
             pad = (pad,) * 6
         if len(pad) != 6:
             raise ValueError(f"For '{self.name}', attr 'pad' should be an positive int number or a tuple of "
-                             f"six positive int numbers, but got `{len(pad)}`.")
+                             f"six positive int numbers, but got `{pad}`.")
         self.pad_list = pad
         self.add_prim_attr('pad_list', self.pad_list)
         validator.check_value_type('pad_mode', pad_mode, [str], self.name)
@@ -8276,7 +8280,7 @@ class Conv3D(PrimitiveWithInfer):
             pad = (pad,) * 6
         if len(pad) != 6:
             raise ValueError(f"For '{self.name}', attr 'pad' should be an positive int number or a tuple of "
-                             f"six positive int numbers, but got `{len(pad)}`.")
+                             f"six positive int numbers, but got `{pad}`.")
         self.add_prim_attr("pad", pad)
         self.padding = pad
         validator.check_value_type('pad_mode', pad_mode, [str], self.name)
@@ -8766,7 +8770,7 @@ class Conv3DTranspose(PrimitiveWithInfer):
             pad = (pad,) * 6
         if len(pad) != 6:
             raise ValueError(f"For '{self.name}', attr 'pad' should be an positive int number or a tuple of "
-                             f"six positive int numbers, but got `{len(pad)}`.")
+                             f"six positive int numbers, but got `{pad}`.")
         self.pad_list = pad
         validator.check_value_type('pad_mode', pad_mode, [str], self.name)
         self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)


@@ -1204,7 +1204,8 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
     net = NetWork()
     with pytest.raises(IndexError) as ex:
         net(input_tensor)
-    assert "'begin' should be in [-6, 6) when shrink axis, but got 'strides': 1, 'begin': -7." in str(ex.value)
+    assert "'begin' should be in [-6, 6) when 'shrink_axis_mask' is greater than 0, " \
+           "but got 'shrink_axis_mask': 7, 'strides': 1, 'begin': -7." in str(ex.value)
 
 
 @pytest.mark.level1
@@ -1226,7 +1227,8 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
     net = NetWork()
     with pytest.raises(IndexError) as ex:
         net(input_tensor)
-    assert "'begin' should be in [-6, 6) when shrink axis, but got 'strides': 1, 'begin': 6." in str(ex.value)
+    assert "'begin' should be in [-6, 6) when 'shrink_axis_mask' is greater than 0, " \
+           "but got 'shrink_axis_mask': 7, 'strides': 1, 'begin': 6." in str(ex.value)
 
 
 @pytest.mark.level0