Merge pull request !24672 from huchunmei/master
This commit is contained in:
i-robot 2021-10-12 03:48:59 +00:00 committed by Gitee
commit cb236727ff
5 changed files with 11 additions and 10 deletions

View File

@ -180,13 +180,13 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
- number = check_is_number(number, int, "bias")
- number = check_is_number(number, int, "bias", "bias_class")
"""
prim_name = f'in \'{prim_name}\'' if prim_name else ''
arg_name = f'\'{arg_name}\'' if arg_name else 'Input value'
prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool):
if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.')
raise ValueError(f'{prim_name} {arg_name} must be legal float, but got `{arg_value}`.')
return arg_value
raise TypeError(f'{arg_name} {prim_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`')
raise TypeError(f'{prim_name} type of {arg_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`')
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
@ -556,8 +556,8 @@ class Validator:
if isinstance(arg_val, type(mstype.tensor)):
arg_val = arg_val.element_type()
if not arg_val in valid_values:
raise TypeError(f'For \'{prim_name}\', the `{arg_key}` should be in {valid_values},'
f' but `{arg_key}` is {arg_val}.')
raise TypeError(f'For \'{prim_name}\', the type of `{arg_key}` should be in {valid_values},'
f' but got {arg_val}.')
return arg
def _check_types_same(arg1, arg2):

View File

@ -906,7 +906,7 @@ class ResizeBilinear(Cell):
super(ResizeBilinear, self).__init__()
def construct(self, x, size=None, scale_factor=None, align_corners=False):
shape = bilinear(x.shape, size, scale_factor, align_corners)
shape = bilinear(x.shape, size, scale_factor, align_corners, self.cls_name)
resize_bilinear = P.ResizeBilinear(shape, align_corners)
return resize_bilinear(x)

View File

@ -884,7 +884,7 @@ class MatMul(Cell):
def construct(self, x1, x2):
x1_shape = self.shape_op(x1)
x2_shape = self.shape_op(x2)
check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2, self.cls_name)
matmul_op = matmul_op_select(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
x1_dim, x2_dim = len(x1_shape), len(x2_shape)

View File

@ -446,7 +446,8 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
continue
output = fn(*args)
return output
raise ValueError(f"For 'MultitypeFuncGraph', cannot find fn match given args.")
raise ValueError(f"For 'MultitypeFuncGraph', cannot find fn match given args. Got (sigs, fn): {self.entries}, "
f"and (dtype, args): {types}.")
def register(self, *type_names):
"""

View File

@ -470,7 +470,7 @@ class _Reduce(PrimitiveWithInfer):
output_min_shape = _infer_shape_reduce(input_x['min_shape'], axis_v, self.keep_dims, self.name)
else:
if axis_v is None:
raise ValueError(f"For {self.name}, axis could not be none.")
raise ValueError(f"For {self.name}, the 'axis' cannot be None.")
out_shape = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
output_max_shape = out_shape
output_min_shape = out_shape