diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py index 822c97cf8b3..a7c33132053 100644 --- a/mindspore/python/mindspore/ops/function/math_func.py +++ b/mindspore/python/mindspore/ops/function/math_func.py @@ -6233,20 +6233,20 @@ def dstack(inputs): [ 6. 12.]]] """ if not isinstance(inputs, (tuple, list)): - raise TypeError(f"For 'dstack', 'diff', 'inputs' must be list or tuple of tensors, but got {type(inputs)}") + raise TypeError(f"For 'dstack', 'inputs' must be list or tuple of tensors, but got {type(inputs)}") if not inputs: - raise TypeError(f"For 'dstack', 'diff', 'inputs' can not be empty.") + raise TypeError(f"For 'dstack', 'inputs' can not be empty.") trans_inputs = () for tensor in inputs: if not isinstance(tensor, Tensor): - raise TypeError(f"For 'dstack', 'diff', each elements of 'inputs' must be Tensor, but got {type(tensor)}") + raise TypeError(f"For 'dstack', each elements of 'inputs' must be Tensor, but got {type(tensor)}") if tensor.ndim <= 1: tensor = _expand(tensor, 2) if tensor.ndim == 2: tensor = P.ExpandDims()(tensor, 2) trans_inputs += (tensor,) if not trans_inputs: - raise ValueError("For 'dstack', 'diff', at least one tensor is needed to concatenate.") + raise ValueError("For 'dstack', at least one tensor is needed to concatenate.") return P.Concat(2)(trans_inputs) @@ -6627,15 +6627,15 @@ def vstack(inputs): [1 5 9]] """ if not isinstance(inputs, (tuple, list)): - msg = f"List or tuple of tensors are required, but got {type(inputs)}" + msg = f"For 'vstack', list or tuple of tensors are required, but got {type(inputs)}" raise TypeError(msg) if not inputs: - msg = "Inputs can not be empty" + msg = "For 'vstack', inputs can not be empty" raise TypeError(msg) trans_tup = () for tensor in inputs: if not isinstance(tensor, Tensor): - msg = f"Tensor is required, but got {type(tensor)}" + msg = f"For 'vstack', Tensor is required, but got {type(tensor)}" raise TypeError(msg) if tensor.ndim <= 1: shape = P.Shape()(tensor) @@ -6647,7 +6647,7 @@ def vstack(inputs): tensor = P.Reshape()(tensor, tuple(shape)) trans_tup += (tensor,) if not trans_tup: - raise ValueError("Need at least one tensor to concatenate.") + raise ValueError("For 'vstack', need at least one tensor to concatenate.") out = P.Concat(0)(trans_tup) return out @@ -6866,6 +6866,11 @@ def copysign(x, other): return P.Select()(less_zero, P.Neg()(pos_tensor), pos_tensor) +@constexpr +def _check_non_negative_int(arg_value, arg_name, prim_name): + validator.check_non_negative_int(arg_value, arg_name, prim_name) + + def hann_window(window_length, periodic=True): r""" Generates a Hann Window. @@ -6897,19 +6902,12 @@ def hann_window(window_length, periodic=True): >>> print(out.asnumpy()) [0. 0.3454915 0.9045085 0.9045085 0.3454915] """ - if not isinstance(window_length, int): - raise TypeError( - f"For 'hann_window', 'window_length' must be a non-negative integer, but got {type(window_length)}" - ) - if window_length < 0: - raise ValueError( - f"For 'hann_window', 'window_length' must be a non-negative integer, but got {window_length}" - ) + _check_non_negative_int(window_length, 'window_length', 'hann_window') if window_length <= 1: return Tensor(np.ones(window_length)) if not isinstance(periodic, (bool, np.bool_)): raise TypeError( - f"For 'kaiser_window', 'periodic' must be a variable of Boolean type, but got {type(periodic)}" + f"For 'hann_window', 'periodic' must be a variable of Boolean type, but got {type(periodic)}" ) if periodic: window_length = window_length + 1 @@ -7834,14 +7832,7 @@ def kaiser_window(window_length, periodic=True, beta=12.0): [5.27734413e-05 1.01719688e-01 7.92939834e-01 7.92939834e-01 1.01719688e-01] """ - if not isinstance(window_length, int): - raise TypeError( - f"For 'kaiser_window', 'window_length' must be a non-negative integer, but got {type(window_length)}" - ) - if window_length < 0: - raise ValueError( - f"For 'kaiser_window', 'window_length' must be a non-negative integer, but got {window_length}" - ) + _check_non_negative_int(window_length, 'window_length', 'kaiser_window') if window_length <= 1: return Tensor(np.ones(window_length)) if not isinstance(periodic, bool):