forked from mindspore-Ecosystem/mindspore
!49682 Optimize function raise msg
Merge pull request !49682 from gaoshuanglong/fix_function_2023
This commit is contained in:
commit
8569558a59
|
@ -6233,20 +6233,20 @@ def dstack(inputs):
|
|||
[ 6. 12.]]]
|
||||
"""
|
||||
if not isinstance(inputs, (tuple, list)):
|
||||
raise TypeError(f"For 'dstack', 'diff', 'inputs' must be list or tuple of tensors, but got {type(inputs)}")
|
||||
raise TypeError(f"For 'dstack', 'inputs' must be list or tuple of tensors, but got {type(inputs)}")
|
||||
if not inputs:
|
||||
raise TypeError(f"For 'dstack', 'diff', 'inputs' can not be empty.")
|
||||
raise TypeError(f"For 'dstack', 'inputs' can not be empty.")
|
||||
trans_inputs = ()
|
||||
for tensor in inputs:
|
||||
if not isinstance(tensor, Tensor):
|
||||
raise TypeError(f"For 'dstack', 'diff', each elements of 'inputs' must be Tensor, but got {type(tensor)}")
|
||||
raise TypeError(f"For 'dstack', each elements of 'inputs' must be Tensor, but got {type(tensor)}")
|
||||
if tensor.ndim <= 1:
|
||||
tensor = _expand(tensor, 2)
|
||||
if tensor.ndim == 2:
|
||||
tensor = P.ExpandDims()(tensor, 2)
|
||||
trans_inputs += (tensor,)
|
||||
if not trans_inputs:
|
||||
raise ValueError("For 'dstack', 'diff', at least one tensor is needed to concatenate.")
|
||||
raise ValueError("For 'dstack', at least one tensor is needed to concatenate.")
|
||||
return P.Concat(2)(trans_inputs)
|
||||
|
||||
|
||||
|
@ -6627,15 +6627,15 @@ def vstack(inputs):
|
|||
[1 5 9]]
|
||||
"""
|
||||
if not isinstance(inputs, (tuple, list)):
|
||||
msg = f"List or tuple of tensors are required, but got {type(inputs)}"
|
||||
msg = f"For 'vstack', list or tuple of tensors are required, but got {type(inputs)}"
|
||||
raise TypeError(msg)
|
||||
if not inputs:
|
||||
msg = "Inputs can not be empty"
|
||||
msg = "For 'vstack', inputs can not be empty"
|
||||
raise TypeError(msg)
|
||||
trans_tup = ()
|
||||
for tensor in inputs:
|
||||
if not isinstance(tensor, Tensor):
|
||||
msg = f"Tensor is required, but got {type(tensor)}"
|
||||
msg = f"For 'vstack', Tensor is required, but got {type(tensor)}"
|
||||
raise TypeError(msg)
|
||||
if tensor.ndim <= 1:
|
||||
shape = P.Shape()(tensor)
|
||||
|
@ -6647,7 +6647,7 @@ def vstack(inputs):
|
|||
tensor = P.Reshape()(tensor, tuple(shape))
|
||||
trans_tup += (tensor,)
|
||||
if not trans_tup:
|
||||
raise ValueError("Need at least one tensor to concatenate.")
|
||||
raise ValueError("For 'vstack', need at least one tensor to concatenate.")
|
||||
out = P.Concat(0)(trans_tup)
|
||||
return out
|
||||
|
||||
|
@ -6866,6 +6866,11 @@ def copysign(x, other):
|
|||
return P.Select()(less_zero, P.Neg()(pos_tensor), pos_tensor)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_non_negative_int(arg_value, arg_name, prim_name):
|
||||
validator.check_non_negative_int(arg_value, arg_name, prim_name)
|
||||
|
||||
|
||||
def hann_window(window_length, periodic=True):
|
||||
r"""
|
||||
Generates a Hann Window.
|
||||
|
@ -6897,19 +6902,12 @@ def hann_window(window_length, periodic=True):
|
|||
>>> print(out.asnumpy())
|
||||
[0. 0.3454915 0.9045085 0.9045085 0.3454915]
|
||||
"""
|
||||
if not isinstance(window_length, int):
|
||||
raise TypeError(
|
||||
f"For 'hann_window', 'window_length' must be a non-negative integer, but got {type(window_length)}"
|
||||
)
|
||||
if window_length < 0:
|
||||
raise ValueError(
|
||||
f"For 'hann_window', 'window_length' must be a non-negative integer, but got {window_length}"
|
||||
)
|
||||
_check_non_negative_int(window_length, 'window_length', 'hann_window')
|
||||
if window_length <= 1:
|
||||
return Tensor(np.ones(window_length))
|
||||
if not isinstance(periodic, (bool, np.bool_)):
|
||||
raise TypeError(
|
||||
f"For 'kaiser_window', 'periodic' must be a variable of Boolean type, but got {type(periodic)}"
|
||||
f"For 'hann_window', 'periodic' must be a variable of Boolean type, but got {type(periodic)}"
|
||||
)
|
||||
if periodic:
|
||||
window_length = window_length + 1
|
||||
|
@ -7834,14 +7832,7 @@ def kaiser_window(window_length, periodic=True, beta=12.0):
|
|||
[5.27734413e-05 1.01719688e-01 7.92939834e-01 7.92939834e-01
|
||||
1.01719688e-01]
|
||||
"""
|
||||
if not isinstance(window_length, int):
|
||||
raise TypeError(
|
||||
f"For 'kaiser_window', 'window_length' must be a non-negative integer, but got {type(window_length)}"
|
||||
)
|
||||
if window_length < 0:
|
||||
raise ValueError(
|
||||
f"For 'kaiser_window', 'window_length' must be a non-negative integer, but got {window_length}"
|
||||
)
|
||||
_check_non_negative_int(window_length, 'window_length', 'kaiser_window')
|
||||
if window_length <= 1:
|
||||
return Tensor(np.ones(window_length))
|
||||
if not isinstance(periodic, bool):
|
||||
|
|
Loading…
Reference in New Issue