!49682 Optimize function raise msg

Merge pull request !49682 from gaoshuanglong/fix_function_2023
This commit is contained in:
i-robot 2023-03-06 02:18:21 +00:00 committed by Gitee
commit 8569558a59
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 16 additions and 25 deletions

View File

@ -6233,20 +6233,20 @@ def dstack(inputs):
[ 6. 12.]]]
"""
if not isinstance(inputs, (tuple, list)):
raise TypeError(f"For 'dstack', 'diff', 'inputs' must be list or tuple of tensors, but got {type(inputs)}")
raise TypeError(f"For 'dstack', 'inputs' must be list or tuple of tensors, but got {type(inputs)}")
if not inputs:
raise TypeError(f"For 'dstack', 'diff', 'inputs' can not be empty.")
raise TypeError(f"For 'dstack', 'inputs' can not be empty.")
trans_inputs = ()
for tensor in inputs:
if not isinstance(tensor, Tensor):
raise TypeError(f"For 'dstack', 'diff', each elements of 'inputs' must be Tensor, but got {type(tensor)}")
raise TypeError(f"For 'dstack', each elements of 'inputs' must be Tensor, but got {type(tensor)}")
if tensor.ndim <= 1:
tensor = _expand(tensor, 2)
if tensor.ndim == 2:
tensor = P.ExpandDims()(tensor, 2)
trans_inputs += (tensor,)
if not trans_inputs:
raise ValueError("For 'dstack', 'diff', at least one tensor is needed to concatenate.")
raise ValueError("For 'dstack', at least one tensor is needed to concatenate.")
return P.Concat(2)(trans_inputs)
@ -6627,15 +6627,15 @@ def vstack(inputs):
[1 5 9]]
"""
if not isinstance(inputs, (tuple, list)):
msg = f"List or tuple of tensors are required, but got {type(inputs)}"
msg = f"For 'vstack', list or tuple of tensors are required, but got {type(inputs)}"
raise TypeError(msg)
if not inputs:
msg = "Inputs can not be empty"
msg = "For 'vstack', inputs can not be empty"
raise TypeError(msg)
trans_tup = ()
for tensor in inputs:
if not isinstance(tensor, Tensor):
msg = f"Tensor is required, but got {type(tensor)}"
msg = f"For 'vstack', Tensor is required, but got {type(tensor)}"
raise TypeError(msg)
if tensor.ndim <= 1:
shape = P.Shape()(tensor)
@ -6647,7 +6647,7 @@ def vstack(inputs):
tensor = P.Reshape()(tensor, tuple(shape))
trans_tup += (tensor,)
if not trans_tup:
raise ValueError("Need at least one tensor to concatenate.")
raise ValueError("For 'vstack', need at least one tensor to concatenate.")
out = P.Concat(0)(trans_tup)
return out
@ -6866,6 +6866,11 @@ def copysign(x, other):
return P.Select()(less_zero, P.Neg()(pos_tensor), pos_tensor)
@constexpr
def _check_non_negative_int(arg_value, arg_name, prim_name):
validator.check_non_negative_int(arg_value, arg_name, prim_name)
def hann_window(window_length, periodic=True):
r"""
Generates a Hann Window.
@ -6897,19 +6902,12 @@ def hann_window(window_length, periodic=True):
>>> print(out.asnumpy())
[0. 0.3454915 0.9045085 0.9045085 0.3454915]
"""
if not isinstance(window_length, int):
raise TypeError(
f"For 'hann_window', 'window_length' must be a non-negative integer, but got {type(window_length)}"
)
if window_length < 0:
raise ValueError(
f"For 'hann_window', 'window_length' must be a non-negative integer, but got {window_length}"
)
_check_non_negative_int(window_length, 'window_length', 'hann_window')
if window_length <= 1:
return Tensor(np.ones(window_length))
if not isinstance(periodic, (bool, np.bool_)):
raise TypeError(
f"For 'kaiser_window', 'periodic' must be a variable of Boolean type, but got {type(periodic)}"
f"For 'hann_window', 'periodic' must be a variable of Boolean type, but got {type(periodic)}"
)
if periodic:
window_length = window_length + 1
@ -7834,14 +7832,7 @@ def kaiser_window(window_length, periodic=True, beta=12.0):
[5.27734413e-05 1.01719688e-01 7.92939834e-01 7.92939834e-01
1.01719688e-01]
"""
if not isinstance(window_length, int):
raise TypeError(
f"For 'kaiser_window', 'window_length' must be a non-negative integer, but got {type(window_length)}"
)
if window_length < 0:
raise ValueError(
f"For 'kaiser_window', 'window_length' must be a non-negative integer, but got {window_length}"
)
_check_non_negative_int(window_length, 'window_length', 'kaiser_window')
if window_length <= 1:
return Tensor(np.ones(window_length))
if not isinstance(periodic, bool):