forked from mindspore-Ecosystem/mindspore
commit
cb236727ff
|
@ -180,13 +180,13 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
|
|||
- number = check_is_number(number, int, "bias")
|
||||
- number = check_is_number(number, int, "bias", "bias_class")
|
||||
"""
|
||||
prim_name = f'in \'{prim_name}\'' if prim_name else ''
|
||||
arg_name = f'\'{arg_name}\'' if arg_name else 'Input value'
|
||||
prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
|
||||
arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
|
||||
if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool):
|
||||
if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
|
||||
raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.')
|
||||
raise ValueError(f'{prim_name} {arg_name} must be legal float, but got `{arg_value}`.')
|
||||
return arg_value
|
||||
raise TypeError(f'{arg_name} {prim_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`')
|
||||
raise TypeError(f'{prim_name} type of {arg_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`')
|
||||
|
||||
|
||||
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
|
||||
|
@ -556,8 +556,8 @@ class Validator:
|
|||
if isinstance(arg_val, type(mstype.tensor)):
|
||||
arg_val = arg_val.element_type()
|
||||
if not arg_val in valid_values:
|
||||
raise TypeError(f'For \'{prim_name}\', the `{arg_key}` should be in {valid_values},'
|
||||
f' but `{arg_key}` is {arg_val}.')
|
||||
raise TypeError(f'For \'{prim_name}\', the type of `{arg_key}` should be in {valid_values},'
|
||||
f' but got {arg_val}.')
|
||||
return arg
|
||||
|
||||
def _check_types_same(arg1, arg2):
|
||||
|
|
|
@ -906,7 +906,7 @@ class ResizeBilinear(Cell):
|
|||
super(ResizeBilinear, self).__init__()
|
||||
|
||||
def construct(self, x, size=None, scale_factor=None, align_corners=False):
|
||||
shape = bilinear(x.shape, size, scale_factor, align_corners)
|
||||
shape = bilinear(x.shape, size, scale_factor, align_corners, self.cls_name)
|
||||
resize_bilinear = P.ResizeBilinear(shape, align_corners)
|
||||
return resize_bilinear(x)
|
||||
|
||||
|
|
|
@ -884,7 +884,7 @@ class MatMul(Cell):
|
|||
def construct(self, x1, x2):
|
||||
x1_shape = self.shape_op(x1)
|
||||
x2_shape = self.shape_op(x2)
|
||||
check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
|
||||
check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2, self.cls_name)
|
||||
matmul_op = matmul_op_select(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
|
||||
|
||||
x1_dim, x2_dim = len(x1_shape), len(x2_shape)
|
||||
|
|
|
@ -446,7 +446,8 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|||
continue
|
||||
output = fn(*args)
|
||||
return output
|
||||
raise ValueError(f"For 'MultitypeFuncGraph', cannot find fn match given args.")
|
||||
raise ValueError(f"For 'MultitypeFuncGraph', cannot find fn match given args. Got (sigs, fn): {self.entries}, "
|
||||
f"and (dtype, args): {types}.")
|
||||
|
||||
def register(self, *type_names):
|
||||
"""
|
||||
|
|
|
@ -470,7 +470,7 @@ class _Reduce(PrimitiveWithInfer):
|
|||
output_min_shape = _infer_shape_reduce(input_x['min_shape'], axis_v, self.keep_dims, self.name)
|
||||
else:
|
||||
if axis_v is None:
|
||||
raise ValueError(f"For {self.name}, axis could not be none.")
|
||||
raise ValueError(f"For {self.name}, the 'axis' cannot be None.")
|
||||
out_shape = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
|
||||
output_max_shape = out_shape
|
||||
output_min_shape = out_shape
|
||||
|
|
Loading…
Reference in New Issue