From 310f59a1477c7d58b547049f61b2c6a1f777b7ad Mon Sep 17 00:00:00 2001 From: fandawei Date: Sun, 13 Nov 2022 21:31:19 +0800 Subject: [PATCH] add ops.clip_by_value --- .../ops/mindspore.ops.func_clip_by_value.rst | 16 ++- .../mindspore/dataset/engine/offload.py | 11 +- mindspore/python/mindspore/nn/optim/thor.py | 5 +- .../probability/distribution/_utils/utils.py | 4 +- .../probability/distribution/categorical.py | 3 +- .../mindspore/ops/composite/__init__.py | 4 +- .../mindspore/ops/composite/clip_ops.py | 104 -------------- .../python/mindspore/ops/function/__init__.py | 5 + .../mindspore/ops/function/clip_func.py | 128 ++++++++++++++++++ .../test_dynamic_wenet_ascend.py | 2 +- .../src/bert_for_pre_training.py | 5 +- .../bert/bert_performance/src/bert_model.py | 8 +- .../models/bert/src/bert_for_pre_training.py | 5 +- .../st/networks/models/bert/src/bert_model.py | 8 +- tests/st/ops/test_ops_clip_func.py | 75 ++++++++++ tests/st/optimizer/test_lamb_op.py | 4 +- tests/ut/python/ops/test_clip_func.py | 127 +++++++++++++++++ tests/ut/python/parallel/test_loss_scale.py | 5 +- 18 files changed, 381 insertions(+), 138 deletions(-) create mode 100644 mindspore/python/mindspore/ops/function/clip_func.py create mode 100644 tests/st/ops/test_ops_clip_func.py create mode 100644 tests/ut/python/ops/test_clip_func.py diff --git a/docs/api/api_python/ops/mindspore.ops.func_clip_by_value.rst b/docs/api/api_python/ops/mindspore.ops.func_clip_by_value.rst index d43ddea3689..baa4d3b3b47 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_clip_by_value.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_clip_by_value.rst @@ -17,13 +17,19 @@ .. note:: - `clip_value_min` 必须小于或等于 `clip_value_max` ; - - :math:`x` , `clip_value_min` 和 `clip_value_max` 的数据类型需支持隐式类型转换,且不能同时为布尔型。 + - :math:`x` , `clip_value_min` 和 `clip_value_max` 的数据类型需支持隐式类型转换,且不能为布尔型。 参数: - - **x** (Tensor) - `clip_by_value` 的输入,任意维度的Tensor。 - - **clip_value_min** (Tensor) - 指定最小值。 - - **clip_value_max** (Tensor) - 指定最大值。 + - **x** (Union(Tensor, list[Tensor], tuple[Tensor])) - `clip_by_value` 的输入,类型为Tensor、Tensor的列表或元组。支持任意维度的Tensor。 + - **clip_value_min** (Union(Tensor, float, int)) - 指定最小值。 + - **clip_value_max** (Union(Tensor, float, int)) - 指定最大值。 返回: - Tensor,表示裁剪后的Tensor。其shape和数据类型和 `x` 相同。 \ No newline at end of file + Tensor、Tensor的列表或元组,表示裁剪后的Tensor。其shape和数据类型和 `x` 相同。 + + 异常: + - **ValueError** - 如果 `clip_value_min` 和 `clip_value_max` 都为None。 + - **TypeError** - 如果 `x` 的数据类型不在Tensor、list[Tensor]或tuple[Tensor]中。 + - **TypeError** - 如果 `clip_value_min` 的数据类型不为None、Tensor、float或int。 + - **TypeError** - 如果 `clip_value_max` 的数据类型不为None、Tensor、float或int。 diff --git a/mindspore/python/mindspore/dataset/engine/offload.py b/mindspore/python/mindspore/dataset/engine/offload.py index 9ac754ed385..1d4be0e08ad 100644 --- a/mindspore/python/mindspore/dataset/engine/offload.py +++ b/mindspore/python/mindspore/dataset/engine/offload.py @@ -19,6 +19,7 @@ import numpy as np import mindspore.common.dtype as mstype from mindspore.common.tensor import Tensor +import mindspore.ops as ops import mindspore.nn as nn import mindspore.ops.composite as C from mindspore.ops import operations as P @@ -308,15 +309,15 @@ class RandomColorAdjust(nn.Cell): # Apply brightness x = self.mul(x, br_rand_factor) - x = C.clip_by_value(x, 0.0, 255.0) + x = ops.clip_by_value(x, 0.0, 255.0) # Apply contrast x = self.mul(x, cont_rand_factor) + self.mul((1 - cont_rand_factor), x_gray_mean) - x = C.clip_by_value(x, 0.0, 255.0) + x = ops.clip_by_value(x, 0.0, 255.0) # Apply 
saturation x = self.mul(x, sat_rand_factor) + self.mul((1 - sat_rand_factor), x_gray) - x = C.clip_by_value(x, 0.0, 255.0) + x = ops.clip_by_value(x, 0.0, 255.0) # Apply Hue Transform # Convert tensor from rgb to hsv @@ -361,7 +362,7 @@ class RandomColorAdjust(nn.Cell): x_rgb = x_rgb + C.repeat_elements(self.expand_dims((v - c), 1), 3, 1) x_rgb = self.transpose(x, (0, 2, 3, 1)) * 255.0 - x_rgb = C.clip_by_value(x, 0.0, 255.0) + x_rgb = ops.clip_by_value(x, 0.0, 255.0) return x_rgb @@ -414,7 +415,7 @@ class RandomSharpness(nn.Cell): x_sharp = self.transpose(x_sharp, (0, 2, 3, 1)) x = self.mul(x, degree_rand_factor) + self.mul((1 - degree_rand_factor), x_sharp) - x = C.clip_by_value(x, 0.0, 255.0) + x = ops.clip_by_value(x, 0.0, 255.0) return x diff --git a/mindspore/python/mindspore/nn/optim/thor.py b/mindspore/python/mindspore/nn/optim/thor.py index 215fe7fda18..4c653d6a8bd 100644 --- a/mindspore/python/mindspore/nn/optim/thor.py +++ b/mindspore/python/mindspore/nn/optim/thor.py @@ -21,6 +21,7 @@ from mindspore.ops import functional as F, composite as C, operations as P from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.tensor import Tensor +import mindspore.ops as ops import mindspore.nn as nn import mindspore.common.dtype as mstype import mindspore.log as logger @@ -87,8 +88,8 @@ def _clip_grad(clip_type, clip_value, grad): return grad dt = F.dtype(grad) if clip_type == 0: - new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), - F.cast(F.tuple_to_array((clip_value,)), dt)) + new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) else: new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) return new_grad diff --git a/mindspore/python/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/python/mindspore/nn/probability/distribution/_utils/utils.py index 9b285d08cb1..1f87f4dc904 100644 --- a/mindspore/python/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/python/mindspore/nn/probability/distribution/_utils/utils.py @@ -19,9 +19,9 @@ from mindspore._checkparam import Validator as validator from mindspore.common.tensor import Tensor from mindspore.common.parameter import Parameter from mindspore.common import dtype as mstype -from mindspore.ops import composite as C from mindspore.ops import operations as P from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register +import mindspore.ops as ops import mindspore.nn as nn @@ -214,7 +214,7 @@ def clamp_probs(probs): clamp probs boundary """ eps = P.Eps()(probs) - return C.clip_by_value(probs, eps, 1-eps) + return ops.clip_by_value(probs, eps, 1-eps) def probs_to_logits(probs, is_binary=False): diff --git a/mindspore/python/mindspore/nn/probability/distribution/categorical.py b/mindspore/python/mindspore/nn/probability/distribution/categorical.py index d2e80d3e657..98795c62477 100644 --- a/mindspore/python/mindspore/nn/probability/distribution/categorical.py +++ b/mindspore/python/mindspore/nn/probability/distribution/categorical.py @@ -20,6 +20,7 @@ from mindspore.ops import composite as C from mindspore.ops.functional import stop_gradient from mindspore.ops.operations import _inner_ops as inner from mindspore._checkparam import Validator +import mindspore.ops as ops import mindspore.nn as nn from mindspore.common import dtype as mstype from .distribution import Distribution @@ 
-141,7 +142,7 @@ class Categorical(Distribution): self.argmax = P.ArgMaxWithValue(axis=-1) self.broadcast = broadcast_to self.cast = P.Cast() - self.clip_by_value = C.clip_by_value + self.clip_by_value = ops.clip_by_value self.concat = P.Concat(-1) self.cumsum = P.CumSum() self.dtypeop = P.DType() diff --git a/mindspore/python/mindspore/ops/composite/__init__.py b/mindspore/python/mindspore/ops/composite/__init__.py index 344a094338b..b603db6ea16 100644 --- a/mindspore/python/mindspore/ops/composite/__init__.py +++ b/mindspore/python/mindspore/ops/composite/__init__.py @@ -23,7 +23,7 @@ from __future__ import absolute_import from mindspore.ops.composite.base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \ tail, zip_operation, _Vmap, _TaylorOperation from mindspore.ops.composite.env_ops import env_get -from mindspore.ops.composite.clip_ops import clip_by_value, clip_by_global_norm +from mindspore.ops.composite.clip_ops import clip_by_global_norm from mindspore.ops.composite.multitype_ops.add_impl import hyper_add from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like @@ -31,6 +31,7 @@ from mindspore.ops.composite.random_ops import normal, laplace, uniform, gamma, from mindspore.ops.composite.math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul, cummin from mindspore.ops.composite.array_ops import repeat_interleave, repeat_elements, sequence_mask from mindspore.ops.composite.vmap_ops import _VmapGeneralPreprocess, _VmapGeneralRule +from mindspore.ops.function.clip_func import clip_by_value __all__ = [ @@ -51,7 +52,6 @@ __all__ = [ 'gamma', 'poisson', 'multinomial', - 'clip_by_value', 'clip_by_global_norm', 'count_nonzero', 'cummin', diff --git a/mindspore/python/mindspore/ops/composite/clip_ops.py b/mindspore/python/mindspore/ops/composite/clip_ops.py index 4807428e3ae..5ebb5c75c0e 100644 --- a/mindspore/python/mindspore/ops/composite/clip_ops.py +++ b/mindspore/python/mindspore/ops/composite/clip_ops.py @@ -15,8 +15,6 @@ """Operations for clipping tensors to min/max values.""" from __future__ import absolute_import - -import numpy as np from mindspore.nn.cell import Cell from mindspore.ops import composite as C from mindspore.ops import functional as F @@ -28,108 +26,6 @@ from mindspore._checkparam import Validator as validator from mindspore.ops.primitive import constexpr -@constexpr -def _check_output_shape(input_shape, out_shape, prim_name=None): - msg_prefix = f"For '{prim_name}', the" if prim_name else "The" - if input_shape != out_shape: - raise ValueError(f"{msg_prefix} input 'x' shape must be equal to the output shape, but got " - f"input 'x' shape {input_shape}, output shape {out_shape}.") - - -def check_np_type(np_dtype, is_max_val): - if not (np.issubsctype(np_dtype, np.floating) or np.issubsctype(np_dtype, np.integer) or - np.issubsctype(np_dtype, np.complex64) or np.issubsctype(np_dtype, np.complex128) or - np.issubsctype(np_dtype, np.bool_)): - value_info = ("clip_value_max", "clip_value_min") if is_max_val else ("clip_value_min", "clip_value_max") - raise ValueError(f"When {value_info[0]} is none, The date type of {value_info[1]} only support integer," - f"floating, bool, complex64 or complex128. 
But got {np_dtype}") - - -@constexpr -def create_max_min_value(ms_type, is_max_val): - """create max or min value""" - np_dtype = mstype.dtype_to_nptype(ms_type) - check_np_type(np_dtype, is_max_val) - if np.issubsctype(np_dtype, np.floating): - val = np.finfo(np_dtype).max if is_max_val else np.finfo(np_dtype).min - elif np.issubsctype(np_dtype, np.integer): - val = np.iinfo(np_dtype).max if is_max_val else np.iinfo(np_dtype).min - elif np.issubsctype(np_dtype, np.complex64): - val = np.finfo(np.float32).max if is_max_val else np.finfo(np.float32).min - val = np.complex64(np.complex(val, val)) - elif np.issubsctype(np_dtype, np.complex128): - val = np.finfo(np.float64).max if is_max_val else np.finfo(np.float64).min - val = np.complex128(np.complex(val, val)) - else: - val = np.bool_(True) if is_max_val else np.bool_(False) - return Tensor(val, ms_type) - - -@constexpr -def raise_value_error(): - raise ValueError("At least one of 'clip_value_min' or 'clip_value_max' must not be None") - - -def clip_by_value(x, clip_value_min=None, clip_value_max=None): - r""" - Clips tensor values to a specified min and max. - - Limits the value of :math:`x` to a range, whose lower limit is `clip_value_min` - and upper limit is `clip_value_max` . - - .. math:: - - out_i= \left\{ - \begin{array}{align} - clip\_value\_max & \text{ if } x_i\ge clip\_value\_max \\ - x_i & \text{ if } clip\_value\_min \lt x_i \lt clip\_value\_max \\ - clip\_value\_min & \text{ if } x_i \le clip\_value\_min \\ - \end{array}\right. - - Note: - `clip_value_min` needs to be less than or equal to `clip_value_max` . The data type of x, `clip_value_min` and - `clip_value_max` should support implicit type conversion and cannot all be bool type. - - Args: - x (Tensor): Input data. The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions. - clip_value_min (Tensor): The minimum value. `clip_value_min` and `clip_value_max` cannot be all None. - Default: None. - clip_value_max (Tensor): The maximum value. `clip_value_min` and `clip_value_max` cannot be all None. - Default: None. - - Returns: - Tensor, a clipped Tensor. The data type is the one with higher precision or higher digits among - the x, `clip_value_min` and `clip_value_max` . - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> from mindspore import Tensor, ops - >>> import numpy as np - >>> min_value = Tensor(5, mindspore.float32) - >>> max_value = Tensor(20, mindspore.float32) - >>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32) - >>> output = ops.clip_by_value(x, min_value, max_value) - >>> print(output) - [[ 5. 20. 5. 7.] - [ 5. 11. 6. 20.]] - """ - - min_op = P.Minimum() - max_op = P.Maximum() - if clip_value_min is None and clip_value_max is None: - raise_value_error() - if clip_value_min is None: - clip_value_min = create_max_min_value(F.dtype(clip_value_max), False) - if clip_value_max is None: - clip_value_max = create_max_min_value(F.dtype(clip_value_min), True) - x_min = min_op(x, clip_value_max) - x_max = max_op(x_min, clip_value_min) - _check_output_shape(F.shape(x), F.shape(x_max), 'clip_by_value') - return x_max - - # The attribute grad_scale is needed for enabling the parallel mode # If this is removed, c.clip_by_global_norm will have precision error in semi/auto parallel mode. 
expand_dims = P.ExpandDims().add_prim_attr("grad_scale", True) diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py index 292e643b7a8..a6da3477d4f 100644 --- a/mindspore/python/mindspore/ops/function/__init__.py +++ b/mindspore/python/mindspore/ops/function/__init__.py @@ -25,6 +25,7 @@ from . import ( math_func, nn_func, linalg_func, + clip_func, ) from .array_func import ( unique, @@ -534,6 +535,9 @@ from .sparse_unary_func import ( coo_sigmoid, coo_sin ) +from .clip_func import ( + clip_by_value, +) __all__ = [] __all__.extend(array_func.__all__) @@ -549,4 +553,5 @@ __all__.extend(image_func.__all__) __all__.extend(spectral_func.__all__) __all__.extend(vmap_func.__all__) __all__.extend(sparse_unary_func.__all__) +__all__.extend(clip_func.__all__) __all__.sort() diff --git a/mindspore/python/mindspore/ops/function/clip_func.py b/mindspore/python/mindspore/ops/function/clip_func.py new file mode 100644 index 00000000000..9c70a477f2a --- /dev/null +++ b/mindspore/python/mindspore/ops/function/clip_func.py @@ -0,0 +1,128 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Defines clip operators with functional form.""" + +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops._primitive_cache import _get_cache_prim +from mindspore.common.tensor import Tensor + +__all__ = [ + 'clip_by_value', +] + +hyper_map = C.HyperMap() +max_op = _get_cache_prim(P.Maximum)() +min_op = _get_cache_prim(P.Minimum)() +cast_op = _get_cache_prim(P.Cast)() +scalar2tensor_op = _get_cache_prim(P.ScalarToTensor)() +partial_op = _get_cache_prim(P.Partial)() + + +def clip_by_value(x, clip_value_min=None, clip_value_max=None): + r""" + Clips tensor values to a specified min and max. + + Limits the value of :math:`x` to a range, whose lower limit is `clip_value_min` + and upper limit is `clip_value_max` . + + .. math:: + + out_i= \left\{ + \begin{array}{align} + clip\_value\_max & \text{ if } x_i\ge clip\_value\_max \\ + x_i & \text{ if } clip\_value\_min \lt x_i \lt clip\_value\_max \\ + clip\_value\_min & \text{ if } x_i \le clip\_value\_min \\ + \end{array}\right. + + Note: + The data type of `x`, `clip_value_min` and `clip_value_max` should support implicit type conversion and cannot + be bool type. + + Args: + x (Union(Tensor, list[Tensor], tuple[Tensor])): Input data, which type is Tensor or a list or tuple of Tensor. + The shape of Tensor is :math:`(N,*)` where :math:`*` means, + any number of additional dimensions. + clip_value_min (Union(Tensor, float, int)): The minimum value. `clip_value_min` and `clip_value_max` + cannot be all None. Default: None. + clip_value_max (Union(Tensor, float, int)): The maximum value. `clip_value_min` and `clip_value_max` + cannot be all None. Default: None. 
+
+    Returns:
+        (Union(Tensor, tuple[Tensor], list[Tensor])), a clipped Tensor or a tuple or a list of clipped Tensors.
+        The data type and shape are the same as `x`.
+
+    Raises:
+        ValueError: If both `clip_value_min` and `clip_value_max` are None.
+        TypeError: If the type of `x` is not Tensor, list[Tensor] or tuple[Tensor].
+        TypeError: If the type of `clip_value_min` is not None, Tensor, float or int.
+        TypeError: If the type of `clip_value_max` is not None, Tensor, float or int.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> # case 1: the data type of x is Tensor
+        >>> from mindspore import Tensor, ops
+        >>> import numpy as np
+        >>> min_value = Tensor(5, mindspore.float32)
+        >>> max_value = Tensor(20, mindspore.float32)
+        >>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
+        >>> output = ops.clip_by_value(x, min_value, max_value)
+        >>> print(output)
+        [[ 5. 20.  5.  7.]
+         [ 5. 11.  6. 20.]]
+        >>> # case 2: the data type of x is list[Tensor]
+        >>> min_value = 5
+        >>> max_value = 20
+        >>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
+        >>> y = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
+        >>> output = ops.clip_by_value([x, y], min_value, max_value)
+        >>> print(output)
+        [[[ 5. 20.  5.  7.]
+          [ 5. 11.  6. 20.]],
+         [[ 5. 20.  5.  7.]
+          [ 5. 11.  6. 20.]]]
+    """
+    def _clip_by_value(clip_min, clip_max, x):
+        if not isinstance(x, Tensor):
+            raise TypeError("The type of 'x' must be Tensor.")
+        result = x
+        if clip_min is not None:
+            result = max_op(result, cast_op(clip_min, x.dtype))
+        if clip_max is not None:
+            result = min_op(result, cast_op(clip_max, x.dtype))
+        return result
+
+    if clip_value_min is None and clip_value_max is None:
+        raise ValueError("At least one of 'clip_value_min' or 'clip_value_max' must not be None.")
+    if not isinstance(x, (Tensor, tuple, list)):
+        raise TypeError("The input of 'clip_by_value' must be a Tensor, a tuple of Tensor or a list of Tensor.")
+    if not isinstance(clip_value_min, (type(None), Tensor, float, int)):
+        raise TypeError("The type of 'clip_value_min' must be one of None, Tensor, float or int.")
+    if not isinstance(clip_value_max, (type(None), Tensor, float, int)):
+        raise TypeError("The type of 'clip_value_max' must be one of None, Tensor, float or int.")
+    if isinstance(clip_value_min, (float, int)):
+        clip_value_min = scalar2tensor_op(clip_value_min)
+    if isinstance(clip_value_max, (float, int)):
+        clip_value_max = scalar2tensor_op(clip_value_max)
+
+    if isinstance(x, Tensor):
+        return _clip_by_value(clip_value_min, clip_value_max, x)
+    results = hyper_map(partial_op(_clip_by_value, clip_value_min, clip_value_max), x)
+    if isinstance(x, tuple):
+        results = tuple(results)
+    return results
diff --git a/tests/st/dynamic_shape/test_dynamic_wenet_ascend.py b/tests/st/dynamic_shape/test_dynamic_wenet_ascend.py
index 643301b67c3..446f0e48bb3 100644
--- a/tests/st/dynamic_shape/test_dynamic_wenet_ascend.py
+++ b/tests/st/dynamic_shape/test_dynamic_wenet_ascend.py
@@ -123,7 +123,7 @@ def _clip_grad(clip_type, clip_value, grad):
         return grad
     dt = F.dtype(grad)
     if clip_type == 0:
-        new_grad = C.clip_by_value(
+        new_grad = ops.clip_by_value(
             grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
             F.cast(F.tuple_to_array((clip_value,)), dt)
         )
diff --git a/tests/st/networks/models/bert/bert_performance/src/bert_for_pre_training.py b/tests/st/networks/models/bert/bert_performance/src/bert_for_pre_training.py
index d04bc473dc2..fe1331b593b 100644
---
a/tests/st/networks/models/bert/bert_performance/src/bert_for_pre_training.py +++ b/tests/st/networks/models/bert/bert_performance/src/bert_for_pre_training.py @@ -14,6 +14,7 @@ # ============================================================================ """Bert for pretraining.""" import numpy as np +import mindspore.ops as ops import mindspore.nn as nn from mindspore import context from mindspore.common import dtype as mstype @@ -51,8 +52,8 @@ def _clip_grad(clip_type, clip_value, grad): return grad dt = F.dtype(grad) if clip_type == 0: - new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), - F.cast(F.tuple_to_array((clip_value,)), dt)) + new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) else: new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) return new_grad diff --git a/tests/st/networks/models/bert/bert_performance/src/bert_model.py b/tests/st/networks/models/bert/bert_performance/src/bert_model.py index f5972f43e64..3a0674c6e4c 100644 --- a/tests/st/networks/models/bert/bert_performance/src/bert_model.py +++ b/tests/st/networks/models/bert/bert_performance/src/bert_model.py @@ -18,12 +18,12 @@ import copy import math import numpy as np import mindspore.common.dtype as mstype +import mindspore.ops as ops import mindspore.nn as nn import mindspore.ops.functional as F from mindspore.common.initializer import TruncatedNormal, initializer from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor -from mindspore.ops import composite as C from mindspore.ops import operations as P @@ -274,9 +274,9 @@ class RelaPosMatrixGenerator(nn.Cell): transpose_out = self.range_mat(tile_col_out, (length, length)) distance_mat = self.sub(range_mat_out, transpose_out) - distance_mat_clipped = C.clip_by_value(distance_mat, - self._min_relative_position, - self._max_relative_position) + distance_mat_clipped = ops.clip_by_value(distance_mat, + self._min_relative_position, + self._max_relative_position) # Shift values to be >=0. Each integer still uniquely identifies a # relative position difference. 
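The `_clip_grad` hook changed above (and again in the files below) is the one recurring element-wise gradient-clipping pattern in this patch: cast the scalar bound to the gradient's dtype and clip with the functional `ops.clip_by_value` instead of the old `C.clip_by_value` composite. Below is a minimal, self-contained sketch of that pattern; the constant names and the eager-mode driver code are illustrative and not part of the patch.

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.ops import functional as F

ms.set_context(mode=ms.PYNATIVE_MODE)   # run the demo eagerly

GRADIENT_CLIP_TYPE = 0     # 0: clip each gradient element-wise, 1: clip by norm
GRADIENT_CLIP_VALUE = 1.0

def clip_grad(clip_type, clip_value, grad):
    """Standalone version of the _clip_grad hook migrated throughout this patch."""
    if clip_type not in (0, 1):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        # ops.clip_by_value replaces the former C.clip_by_value composite call
        new_grad = ops.clip_by_value(grad,
                                     F.cast(F.tuple_to_array((-clip_value,)), dt),
                                     F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad

grad = Tensor(np.array([-3.0, 0.5, 2.5]), ms.float32)
print(clip_grad(GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE, grad))   # values limited to [-1.0, 1.0]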
diff --git a/tests/st/networks/models/bert/src/bert_for_pre_training.py b/tests/st/networks/models/bert/src/bert_for_pre_training.py index 6e61a63b4de..0a92dbb20c2 100644 --- a/tests/st/networks/models/bert/src/bert_for_pre_training.py +++ b/tests/st/networks/models/bert/src/bert_for_pre_training.py @@ -15,6 +15,7 @@ """Bert for pretraining.""" import numpy as np +import mindspore.ops as ops import mindspore.nn as nn from mindspore.common.initializer import initializer, TruncatedNormal from mindspore.ops import operations as P @@ -53,8 +54,8 @@ def _clip_grad(clip_type, clip_value, grad): return grad dt = F.dtype(grad) if clip_type == 0: - new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), - F.cast(F.tuple_to_array((clip_value,)), dt)) + new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) else: new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) return new_grad diff --git a/tests/st/networks/models/bert/src/bert_model.py b/tests/st/networks/models/bert/src/bert_model.py index e2da7ed2752..51e64c88712 100644 --- a/tests/st/networks/models/bert/src/bert_model.py +++ b/tests/st/networks/models/bert/src/bert_model.py @@ -18,11 +18,11 @@ import math import copy import numpy as np import mindspore.common.dtype as mstype +import mindspore.ops as ops import mindspore.nn as nn import mindspore.ops.functional as F from mindspore.common.initializer import TruncatedNormal, initializer from mindspore.ops import operations as P -from mindspore.ops import composite as C from mindspore.common.tensor import Tensor from mindspore.common.parameter import Parameter @@ -291,9 +291,9 @@ class RelaPosMatrixGenerator(nn.Cell): transpose_out = self.range_mat(tile_col_out, (self._length, self._length)) distance_mat = self.sub(range_mat_out, transpose_out) - distance_mat_clipped = C.clip_by_value(distance_mat, - self._min_relative_position, - self._max_relative_position) + distance_mat_clipped = ops.clip_by_value(distance_mat, + self._min_relative_position, + self._max_relative_position) # Shift values to be >=0. Each integer still uniquely identifies a # relative position difference. diff --git a/tests/st/ops/test_ops_clip_func.py b/tests/st/ops/test_ops_clip_func.py new file mode 100644 index 00000000000..506c8f2b868 --- /dev/null +++ b/tests/st/ops/test_ops_clip_func.py @@ -0,0 +1,75 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np +import pytest +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor, ops + + +class NetWorkClipByValue(nn.Cell): + def construct(self, x, min_value, max_value): + return ops.clip_by_value(x, min_value, max_value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ops_clip_by_value_tensor(mode): + """ + Feature: ops.clip_by_value + Description: Verify the result of clip_by_value + Expectation: success + """ + ms.set_context(mode=mode) + x = Tensor(np.array([-0.5962, 0.4985, 0.2349, -0.4396, 0.4525]), ms.float32) + net = NetWorkClipByValue() + output = net(x, -0.3, 0.4) + expect_output = [-0.3, 0.4, 0.2349, -0.3, 0.4] + assert np.allclose(output.asnumpy(), expect_output) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_clip_by_value_list_tensor(mode): + """ + Feature: ops.clip_by_value + Description: Verify the result of clip_by_value + Expectation: success + """ + ms.set_context(mode=mode) + x1 = Tensor(np.array([-0.5962, 0.4985, 0.2349, -0.4396, 0.4525]), ms.float32) + x2 = Tensor(np.array([0.6035, 0.6959, 0.0150, -0.5766, 0.5432]), ms.float32) + x3 = Tensor(np.array([0.7549, 0.1056, 0.3312, -0.4060, 0.9821]), ms.float32) + net = NetWorkClipByValue() + output = net([x1, x2, x3], -0.3, 0.4) + expect_output = [[-0.3, 0.4, 0.2349, -0.3, 0.4], + [0.4, 0.4, 0.0150, -0.3, 0.4], + [0.4, 0.1056, 0.3312, -0.3, 0.4] + ] + assert np.allclose(output[0].asnumpy(), expect_output[0]) + assert np.allclose(output[1].asnumpy(), expect_output[1]) + assert np.allclose(output[2].asnumpy(), expect_output[2]) diff --git a/tests/st/optimizer/test_lamb_op.py b/tests/st/optimizer/test_lamb_op.py index abc7b86ec23..a310c6f1a1c 100644 --- a/tests/st/optimizer/test_lamb_op.py +++ b/tests/st/optimizer/test_lamb_op.py @@ -16,6 +16,7 @@ import numpy as np import pytest import mindspore.context as context +import mindspore.ops as ops import mindspore.nn as nn from mindspore import Tensor from mindspore.common import dtype as mstype @@ -23,7 +24,6 @@ from mindspore.ops import operations as P from mindspore.common.parameter import Parameter from mindspore.nn import layer from mindspore.ops import functional as F -from mindspore.ops import composite as C from mindspore.ops.operations import _inner_ops as inner num_one = Tensor(np.ones([1]), mstype.float32) @@ -88,7 +88,7 @@ class LambGPUOrigin(nn.Cell): self.op_select(self.op_greater(g_norm, zeros), w_norm / g_norm_hat, ones), ones) tens = self.op_fill(self.op_dtype(trust_ratio), self.op_shape(trust_ratio), 10.0) - trust_ratio = C.clip_by_value(trust_ratio, zeros, tens) + trust_ratio = ops.clip_by_value(trust_ratio, zeros, tens) update = next_mm / (self.op_sqrt(next_vv) + eps) if decay_flag: diff --git a/tests/ut/python/ops/test_clip_func.py b/tests/ut/python/ops/test_clip_func.py new file mode 100644 index 00000000000..6b9113a0c90 --- /dev/null +++ b/tests/ut/python/ops/test_clip_func.py @@ -0,0 +1,127 @@ +# Copyright 2022 
Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_clip_func """ +import functools + +import numpy as np +import mindspore.nn as nn +from mindspore import ops +import mindspore.context as context +from mindspore import Tensor +from tests.mindspore_test_framework.mindspore_test import mindspore_test +from tests.mindspore_test_framework.pipeline.forward.compile_forward \ + import pipeline_for_compile_forward_ge_graph_for_case_by_case_config +from tests.mindspore_test_framework.pipeline.forward.verify_exception \ + import pipeline_for_verify_exception_for_case_by_case_config + +context.set_context(mode=(context.GRAPH_MODE)) + + + +class NetWorkClipByValue(nn.Cell): + __doc__ = ' NetWorkClipByValue definition ' + + def __init__(self): + super(NetWorkClipByValue, self).__init__() + self.clip_func = ops.clip_by_value + + def construct(self, x, min_value, max_value): + return self.clip_func(x, min_value, max_value) + + +class NetWorkClipListByValue(nn.Cell): + __doc__ = ' NetWorkClipListByValue definition ' + + def __init__(self): + super(NetWorkClipListByValue, self).__init__() + self.clip_func = ops.clip_by_value + + def construct(self, x1, x2, x3, min_value, max_value): + return self.clip_func([x1, x2, x3], min_value, max_value) + + +test_case_clip_func = [ + ('ClipbyValue_1', { + 'block': NetWorkClipByValue(), + 'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)), + 2, + 8.0], + 'skip': ['backward']}), + ('ClipbyValue_2', { + 'block': NetWorkClipByValue(), + 'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)), + None, + 8], + 'skip': ['backward']}), + ('ClipbyValue_3', { + 'block': NetWorkClipByValue(), + 'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)), + Tensor(np.array([2]).astype(np.float32)), + Tensor(np.array(8).astype(np.float32))], + 'skip': ['backward']}), + ('ClipListbyValue_1', { + 'block': NetWorkClipListByValue(), + 'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)), + Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)), + Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)), + Tensor(np.array(2).astype(np.float32)), + 8], + 'skip': ['backward']}), +] + + +test_cases_for_verify_exception = [ + ('ClipByValueERR_1', { + 'block': (NetWorkClipByValue(), {'exception': ValueError}), + 'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)), + Tensor(np.array([2, 3]).astype(np.float32)), + 8]}), + ('ClipByValueERR_2', { + 'block': (NetWorkClipByValue(), {'exception': TypeError}), + 'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)), + '2', + 8]}), + ('ClipByValueERR_3', { + 'block': (NetWorkClipByValue(), {'exception': TypeError}), + 'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)), + 2, + '8']}), + ('ClipByValueERR_4', { + 'block': (NetWorkClipByValue(), 
{'exception': ValueError}),
+        'desc_inputs': [Tensor(np.ones([1, 3, 3]).astype(np.float32)),
+                        None,
+                        None]}),
+]
+
+
+@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
+def test_exec():
+    """
+    Feature: ops.clip_by_value
+    Description: compile the forward graphs of the clip_by_value test cases
+    Expectation: success
+    """
+    context.set_context(mode=(context.GRAPH_MODE))
+    return functools.reduce(lambda x, y: x + y, [test_case_clip_func])
+
+
+@mindspore_test(pipeline_for_verify_exception_for_case_by_case_config)
+def test_check_exception():
+    """
+    Feature: ops.clip_by_value
+    Description: verify that invalid clip_by_value inputs raise the expected exceptions
+    Expectation: throw errors
+    """
+    return test_cases_for_verify_exception
diff --git a/tests/ut/python/parallel/test_loss_scale.py b/tests/ut/python/parallel/test_loss_scale.py
index 93bb8b02560..7fd568d724d 100644
--- a/tests/ut/python/parallel/test_loss_scale.py
+++ b/tests/ut/python/parallel/test_loss_scale.py
@@ -24,6 +24,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
+import mindspore.ops as ops
 import mindspore.nn as nn
 from mindspore.train import Model
 from mindspore.context import ParallelMode
@@ -48,8 +49,8 @@ update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2,
 def _clip_grad(clip_type, clip_value, grad):
     dt = F.dtype(grad)
     if clip_type == 0:
-        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
-                                   F.cast(F.tuple_to_array((clip_value,)), dt))
+        new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
+                                     F.cast(F.tuple_to_array((clip_value,)), dt))
     else:
         new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
     return new_grad
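As a usage sketch of the functional interface this patch adds, the snippet below exercises a single Tensor, a single bound, and a list of Tensors with scalar bounds. The input values mirror the new ST tests; the variable names and the PyNative-mode driver are illustrative only.

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

ms.set_context(mode=ms.PYNATIVE_MODE)

x = Tensor(np.array([-0.5962, 0.4985, 0.2349, -0.4396, 0.4525]), ms.float32)
y = Tensor(np.array([0.6035, 0.6959, 0.0150, -0.5766, 0.5432]), ms.float32)

# Single Tensor with scalar bounds; scalars are converted to tensors internally.
print(ops.clip_by_value(x, -0.3, 0.4))

# Only one bound is required: clip_value_min may stay None.
print(ops.clip_by_value(x, clip_value_max=0.4))

# A list (or tuple) of Tensors is clipped per element and returned as a list (or tuple).
outputs = ops.clip_by_value([x, y], -0.3, 0.4)
for out in outputs:
    print(out)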