forked from mindspore-Ecosystem/mindspore
!49682 Optimize function raise msg
Merge pull request !49682 from gaoshuanglong/fix_function_2023
This commit is contained in:
commit
8569558a59
|
@ -6233,20 +6233,20 @@ def dstack(inputs):
|
||||||
[ 6. 12.]]]
|
[ 6. 12.]]]
|
||||||
"""
|
"""
|
||||||
if not isinstance(inputs, (tuple, list)):
|
if not isinstance(inputs, (tuple, list)):
|
||||||
raise TypeError(f"For 'dstack', 'diff', 'inputs' must be list or tuple of tensors, but got {type(inputs)}")
|
raise TypeError(f"For 'dstack', 'inputs' must be list or tuple of tensors, but got {type(inputs)}")
|
||||||
if not inputs:
|
if not inputs:
|
||||||
raise TypeError(f"For 'dstack', 'diff', 'inputs' can not be empty.")
|
raise TypeError(f"For 'dstack', 'inputs' can not be empty.")
|
||||||
trans_inputs = ()
|
trans_inputs = ()
|
||||||
for tensor in inputs:
|
for tensor in inputs:
|
||||||
if not isinstance(tensor, Tensor):
|
if not isinstance(tensor, Tensor):
|
||||||
raise TypeError(f"For 'dstack', 'diff', each elements of 'inputs' must be Tensor, but got {type(tensor)}")
|
raise TypeError(f"For 'dstack', each elements of 'inputs' must be Tensor, but got {type(tensor)}")
|
||||||
if tensor.ndim <= 1:
|
if tensor.ndim <= 1:
|
||||||
tensor = _expand(tensor, 2)
|
tensor = _expand(tensor, 2)
|
||||||
if tensor.ndim == 2:
|
if tensor.ndim == 2:
|
||||||
tensor = P.ExpandDims()(tensor, 2)
|
tensor = P.ExpandDims()(tensor, 2)
|
||||||
trans_inputs += (tensor,)
|
trans_inputs += (tensor,)
|
||||||
if not trans_inputs:
|
if not trans_inputs:
|
||||||
raise ValueError("For 'dstack', 'diff', at least one tensor is needed to concatenate.")
|
raise ValueError("For 'dstack', at least one tensor is needed to concatenate.")
|
||||||
return P.Concat(2)(trans_inputs)
|
return P.Concat(2)(trans_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -6627,15 +6627,15 @@ def vstack(inputs):
|
||||||
[1 5 9]]
|
[1 5 9]]
|
||||||
"""
|
"""
|
||||||
if not isinstance(inputs, (tuple, list)):
|
if not isinstance(inputs, (tuple, list)):
|
||||||
msg = f"List or tuple of tensors are required, but got {type(inputs)}"
|
msg = f"For 'vstack', list or tuple of tensors are required, but got {type(inputs)}"
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
if not inputs:
|
if not inputs:
|
||||||
msg = "Inputs can not be empty"
|
msg = "For 'vstack', inputs can not be empty"
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
trans_tup = ()
|
trans_tup = ()
|
||||||
for tensor in inputs:
|
for tensor in inputs:
|
||||||
if not isinstance(tensor, Tensor):
|
if not isinstance(tensor, Tensor):
|
||||||
msg = f"Tensor is required, but got {type(tensor)}"
|
msg = f"For 'vstack', Tensor is required, but got {type(tensor)}"
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
if tensor.ndim <= 1:
|
if tensor.ndim <= 1:
|
||||||
shape = P.Shape()(tensor)
|
shape = P.Shape()(tensor)
|
||||||
|
@ -6647,7 +6647,7 @@ def vstack(inputs):
|
||||||
tensor = P.Reshape()(tensor, tuple(shape))
|
tensor = P.Reshape()(tensor, tuple(shape))
|
||||||
trans_tup += (tensor,)
|
trans_tup += (tensor,)
|
||||||
if not trans_tup:
|
if not trans_tup:
|
||||||
raise ValueError("Need at least one tensor to concatenate.")
|
raise ValueError("For 'vstack', need at least one tensor to concatenate.")
|
||||||
out = P.Concat(0)(trans_tup)
|
out = P.Concat(0)(trans_tup)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -6866,6 +6866,11 @@ def copysign(x, other):
|
||||||
return P.Select()(less_zero, P.Neg()(pos_tensor), pos_tensor)
|
return P.Select()(less_zero, P.Neg()(pos_tensor), pos_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _check_non_negative_int(arg_value, arg_name, prim_name):
|
||||||
|
validator.check_non_negative_int(arg_value, arg_name, prim_name)
|
||||||
|
|
||||||
|
|
||||||
def hann_window(window_length, periodic=True):
|
def hann_window(window_length, periodic=True):
|
||||||
r"""
|
r"""
|
||||||
Generates a Hann Window.
|
Generates a Hann Window.
|
||||||
|
@ -6897,19 +6902,12 @@ def hann_window(window_length, periodic=True):
|
||||||
>>> print(out.asnumpy())
|
>>> print(out.asnumpy())
|
||||||
[0. 0.3454915 0.9045085 0.9045085 0.3454915]
|
[0. 0.3454915 0.9045085 0.9045085 0.3454915]
|
||||||
"""
|
"""
|
||||||
if not isinstance(window_length, int):
|
_check_non_negative_int(window_length, 'window_length', 'hann_window')
|
||||||
raise TypeError(
|
|
||||||
f"For 'hann_window', 'window_length' must be a non-negative integer, but got {type(window_length)}"
|
|
||||||
)
|
|
||||||
if window_length < 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"For 'hann_window', 'window_length' must be a non-negative integer, but got {window_length}"
|
|
||||||
)
|
|
||||||
if window_length <= 1:
|
if window_length <= 1:
|
||||||
return Tensor(np.ones(window_length))
|
return Tensor(np.ones(window_length))
|
||||||
if not isinstance(periodic, (bool, np.bool_)):
|
if not isinstance(periodic, (bool, np.bool_)):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"For 'kaiser_window', 'periodic' must be a variable of Boolean type, but got {type(periodic)}"
|
f"For 'hann_window', 'periodic' must be a variable of Boolean type, but got {type(periodic)}"
|
||||||
)
|
)
|
||||||
if periodic:
|
if periodic:
|
||||||
window_length = window_length + 1
|
window_length = window_length + 1
|
||||||
|
@ -7834,14 +7832,7 @@ def kaiser_window(window_length, periodic=True, beta=12.0):
|
||||||
[5.27734413e-05 1.01719688e-01 7.92939834e-01 7.92939834e-01
|
[5.27734413e-05 1.01719688e-01 7.92939834e-01 7.92939834e-01
|
||||||
1.01719688e-01]
|
1.01719688e-01]
|
||||||
"""
|
"""
|
||||||
if not isinstance(window_length, int):
|
_check_non_negative_int(window_length, 'window_length', 'kaiser_window')
|
||||||
raise TypeError(
|
|
||||||
f"For 'kaiser_window', 'window_length' must be a non-negative integer, but got {type(window_length)}"
|
|
||||||
)
|
|
||||||
if window_length < 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"For 'kaiser_window', 'window_length' must be a non-negative integer, but got {window_length}"
|
|
||||||
)
|
|
||||||
if window_length <= 1:
|
if window_length <= 1:
|
||||||
return Tensor(np.ones(window_length))
|
return Tensor(np.ones(window_length))
|
||||||
if not isinstance(periodic, bool):
|
if not isinstance(periodic, bool):
|
||||||
|
|
Loading…
Reference in New Issue