!44835 [Task I55EAD] Optimizer provides gradient clipping

Merge pull request !44835 from DavidFFFan/grad_clip
This commit is contained in:
i-robot 2022-11-21 01:58:41 +00:00 committed by Gitee
commit 9414c6c558
18 changed files with 381 additions and 138 deletions


@@ -17,13 +17,19 @@
.. note::
- `clip_value_min` must be less than or equal to `clip_value_max`.
- The data types of :math:`x`, `clip_value_min` and `clip_value_max` must support implicit type conversion and cannot all be bool.
- The data types of :math:`x`, `clip_value_min` and `clip_value_max` must support implicit type conversion and cannot be bool.
Args:
- **x** (Tensor) - The input of `clip_by_value`, a Tensor of any dimension.
- **clip_value_min** (Tensor) - The minimum value.
- **clip_value_max** (Tensor) - The maximum value.
- **x** (Union(Tensor, list[Tensor], tuple[Tensor])) - The input of `clip_by_value`: a Tensor, or a list or tuple of Tensors. Tensors of any dimension are supported.
- **clip_value_min** (Union(Tensor, float, int)) - The minimum value.
- **clip_value_max** (Union(Tensor, float, int)) - The maximum value.
Returns:
Tensor, the clipped Tensor, whose shape and data type are the same as `x`.
Tensor, or a list or tuple of Tensors, the clipped result, whose shape and data type are the same as `x`.
Raises:
- **ValueError** - If both `clip_value_min` and `clip_value_max` are None.
- **TypeError** - If the type of `x` is not Tensor, list[Tensor] or tuple[Tensor].
- **TypeError** - If the type of `clip_value_min` is not None, Tensor, float or int.
- **TypeError** - If the type of `clip_value_max` is not None, Tensor, float or int.
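Taken together, the doc diff above widens `clip_by_value` from a single Tensor with Tensor bounds to container inputs and scalar bounds. A minimal usage sketch of the widened interface (illustrative values, assuming the `ops.clip_by_value` introduced by this PR):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.array([1., 25., 5., 7.]), ms.float32)
y = Tensor(np.array([4., 11., 6., 21.]), ms.float32)

# Scalar bounds are now accepted in place of Tensor bounds.
print(ops.clip_by_value(x, 5, 20.0))    # [ 5. 20.  5.  7.]

# A list of Tensors is clipped element-wise; the container type is kept.
out = ops.clip_by_value([x, y], 5, 20.0)
print(out[1])                           # [ 5. 11.  6. 20.]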


@@ -19,6 +19,7 @@ import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore.ops import operations as P
@@ -308,15 +309,15 @@ class RandomColorAdjust(nn.Cell):
# Apply brightness
x = self.mul(x, br_rand_factor)
x = C.clip_by_value(x, 0.0, 255.0)
x = ops.clip_by_value(x, 0.0, 255.0)
# Apply contrast
x = self.mul(x, cont_rand_factor) + self.mul((1 - cont_rand_factor), x_gray_mean)
x = C.clip_by_value(x, 0.0, 255.0)
x = ops.clip_by_value(x, 0.0, 255.0)
# Apply saturation
x = self.mul(x, sat_rand_factor) + self.mul((1 - sat_rand_factor), x_gray)
x = C.clip_by_value(x, 0.0, 255.0)
x = ops.clip_by_value(x, 0.0, 255.0)
# Apply Hue Transform
# Convert tensor from rgb to hsv
@@ -361,7 +362,7 @@ class RandomColorAdjust(nn.Cell):
x_rgb = x_rgb + C.repeat_elements(self.expand_dims((v - c), 1), 3, 1)
x_rgb = self.transpose(x, (0, 2, 3, 1)) * 255.0
x_rgb = C.clip_by_value(x, 0.0, 255.0)
x_rgb = ops.clip_by_value(x, 0.0, 255.0)
return x_rgb
@@ -414,7 +415,7 @@ class RandomSharpness(nn.Cell):
x_sharp = self.transpose(x_sharp, (0, 2, 3, 1))
x = self.mul(x, degree_rand_factor) + self.mul((1 - degree_rand_factor), x_sharp)
x = C.clip_by_value(x, 0.0, 255.0)
x = ops.clip_by_value(x, 0.0, 255.0)
return x


@@ -21,6 +21,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.log as logger
@@ -87,8 +88,8 @@ def _clip_grad(clip_type, clip_value, grad):
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
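`_clip_grad` is a per-gradient helper: `clip_type` 0 clips each gradient by value, any other value clips by norm. In the usual MindSpore training-wrapper pattern it is registered on a `MultitypeFuncGraph` and mapped over the gradient tuple before the optimizer update; a hedged fragment of that wiring (the names `GRADIENT_CLIP_TYPE`, `GRADIENT_CLIP_VALUE` and the surrounding `construct` are illustrative assumptions, not code from this diff):

# Inside a hypothetical TrainOneStepCell.construct, after the backward pass:
grads = self.grad(self.network, self.weights)(data, label, sens)
# Apply _clip_grad to every gradient in the tuple.
grads = self.hyper_map(
    ops.partial(_clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
    grads)
loss = ops.depend(loss, self.optimizer(grads))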


@@ -19,9 +19,9 @@ from mindspore._checkparam import Validator as validator
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
import mindspore.ops as ops
import mindspore.nn as nn
@@ -214,7 +214,7 @@ def clamp_probs(probs):
clamp probs boundary
"""
eps = P.Eps()(probs)
return C.clip_by_value(probs, eps, 1-eps)
return ops.clip_by_value(probs, eps, 1-eps)
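`clamp_probs` keeps probabilities strictly inside (0, 1) so that downstream `log` calls stay finite. A short sketch of the effect (assuming float32 inputs and the functional `ops.clip_by_value` path this PR introduces):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops
from mindspore.ops import operations as P

probs = Tensor(np.array([0.0, 0.5, 1.0]), ms.float32)
eps = P.Eps()(probs)                           # per-element machine epsilon
safe = ops.clip_by_value(probs, eps, 1 - eps)  # now strictly inside (0, 1)
print(ops.log(safe))                           # finite, where log(0) is -inf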
def probs_to_logits(probs, is_binary=False):


@@ -20,6 +20,7 @@ from mindspore.ops import composite as C
from mindspore.ops.functional import stop_gradient
from mindspore.ops.operations import _inner_ops as inner
from mindspore._checkparam import Validator
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from .distribution import Distribution
@@ -141,7 +142,7 @@ class Categorical(Distribution):
self.argmax = P.ArgMaxWithValue(axis=-1)
self.broadcast = broadcast_to
self.cast = P.Cast()
self.clip_by_value = C.clip_by_value
self.clip_by_value = ops.clip_by_value
self.concat = P.Concat(-1)
self.cumsum = P.CumSum()
self.dtypeop = P.DType()


@@ -23,7 +23,7 @@ from __future__ import absolute_import
from mindspore.ops.composite.base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \
tail, zip_operation, _Vmap, _TaylorOperation
from mindspore.ops.composite.env_ops import env_get
from mindspore.ops.composite.clip_ops import clip_by_value, clip_by_global_norm
from mindspore.ops.composite.clip_ops import clip_by_global_norm
from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
@@ -31,6 +31,7 @@ from mindspore.ops.composite.random_ops import normal, laplace, uniform, gamma,
from mindspore.ops.composite.math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul, cummin
from mindspore.ops.composite.array_ops import repeat_interleave, repeat_elements, sequence_mask
from mindspore.ops.composite.vmap_ops import _VmapGeneralPreprocess, _VmapGeneralRule
from mindspore.ops.function.clip_func import clip_by_value
__all__ = [
@@ -51,7 +52,6 @@ __all__ = [
'gamma',
'poisson',
'multinomial',
'clip_by_value',
'clip_by_global_norm',
'count_nonzero',
'cummin',
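With these hunks, `clip_by_value` moves to `mindspore.ops.function.clip_func` and is dropped from `composite.__all__`, while the `from mindspore.ops.function.clip_func import clip_by_value` line above keeps the symbol resolvable inside the composite module. A sketch of the resulting import paths (the continued availability of the legacy path is an assumption based on that compatibility import):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops
from mindspore.ops import composite as C

x = Tensor(np.array([-1.0, 128.0, 300.0]), ms.float32)

# Preferred public path after this PR:
print(ops.clip_by_value(x, 0.0, 255.0))   # [  0. 128. 255.]

# Legacy path, no longer advertised in composite.__all__ (assumption):
print(C.clip_by_value(x, 0.0, 255.0))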


@@ -15,8 +15,6 @@
"""Operations for clipping tensors to min/max values."""
from __future__ import absolute_import
import numpy as np
from mindspore.nn.cell import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
@@ -28,108 +26,6 @@ from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
@constexpr
def _check_output_shape(input_shape, out_shape, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if input_shape != out_shape:
raise ValueError(f"{msg_prefix} input 'x' shape must be equal to the output shape, but got "
f"input 'x' shape {input_shape}, output shape {out_shape}.")
def check_np_type(np_dtype, is_max_val):
if not (np.issubsctype(np_dtype, np.floating) or np.issubsctype(np_dtype, np.integer) or
np.issubsctype(np_dtype, np.complex64) or np.issubsctype(np_dtype, np.complex128) or
np.issubsctype(np_dtype, np.bool_)):
value_info = ("clip_value_max", "clip_value_min") if is_max_val else ("clip_value_min", "clip_value_max")
raise ValueError(f"When {value_info[0]} is none, The date type of {value_info[1]} only support integer,"
f"floating, bool, complex64 or complex128. But got {np_dtype}")
@constexpr
def create_max_min_value(ms_type, is_max_val):
"""create max or min value"""
np_dtype = mstype.dtype_to_nptype(ms_type)
check_np_type(np_dtype, is_max_val)
if np.issubsctype(np_dtype, np.floating):
val = np.finfo(np_dtype).max if is_max_val else np.finfo(np_dtype).min
elif np.issubsctype(np_dtype, np.integer):
val = np.iinfo(np_dtype).max if is_max_val else np.iinfo(np_dtype).min
elif np.issubsctype(np_dtype, np.complex64):
val = np.finfo(np.float32).max if is_max_val else np.finfo(np.float32).min
val = np.complex64(np.complex(val, val))
elif np.issubsctype(np_dtype, np.complex128):
val = np.finfo(np.float64).max if is_max_val else np.finfo(np.float64).min
val = np.complex128(np.complex(val, val))
else:
val = np.bool_(True) if is_max_val else np.bool_(False)
return Tensor(val, ms_type)
@constexpr
def raise_value_error():
raise ValueError("At least one of 'clip_value_min' or 'clip_value_max' must not be None")
def clip_by_value(x, clip_value_min=None, clip_value_max=None):
r"""
Clips tensor values to a specified min and max.
Limits the value of :math:`x` to a range, whose lower limit is `clip_value_min`
and upper limit is `clip_value_max` .
.. math::
out_i= \left\{
\begin{array}{align}
clip\_value\_max & \text{ if } x_i\ge clip\_value\_max \\
x_i & \text{ if } clip\_value\_min \lt x_i \lt clip\_value\_max \\
clip\_value\_min & \text{ if } x_i \le clip\_value\_min \\
\end{array}\right.
Note:
`clip_value_min` needs to be less than or equal to `clip_value_max` . The data type of x, `clip_value_min` and
`clip_value_max` should support implicit type conversion and cannot all be bool type.
Args:
x (Tensor): Input data. The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
clip_value_min (Tensor): The minimum value. `clip_value_min` and `clip_value_max` cannot be all None.
Default: None.
clip_value_max (Tensor): The maximum value. `clip_value_min` and `clip_value_max` cannot be all None.
Default: None.
Returns:
Tensor, a clipped Tensor. The data type is the one with higher precision or higher digits among
the x, `clip_value_min` and `clip_value_max` .
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore import Tensor, ops
>>> import numpy as np
>>> min_value = Tensor(5, mindspore.float32)
>>> max_value = Tensor(20, mindspore.float32)
>>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
>>> output = ops.clip_by_value(x, min_value, max_value)
>>> print(output)
[[ 5. 20. 5. 7.]
[ 5. 11. 6. 20.]]
"""
min_op = P.Minimum()
max_op = P.Maximum()
if clip_value_min is None and clip_value_max is None:
raise_value_error()
if clip_value_min is None:
clip_value_min = create_max_min_value(F.dtype(clip_value_max), False)
if clip_value_max is None:
clip_value_max = create_max_min_value(F.dtype(clip_value_min), True)
x_min = min_op(x, clip_value_max)
x_max = max_op(x_min, clip_value_min)
_check_output_shape(F.shape(x), F.shape(x_max), 'clip_by_value')
return x_max
# The attribute grad_scale is needed for enabling the parallel mode
# If this is removed, C.clip_by_global_norm will have precision error in semi/auto parallel mode.
expand_dims = P.ExpandDims().add_prim_attr("grad_scale", True)


@@ -25,6 +25,7 @@ from . import (
math_func,
nn_func,
linalg_func,
clip_func,
)
from .array_func import (
unique,
@@ -533,6 +534,9 @@ from .sparse_unary_func import (
coo_sigmoid,
coo_sin
)
from .clip_func import (
clip_by_value,
)
__all__ = []
__all__.extend(array_func.__all__)
@@ -548,4 +552,5 @@ __all__.extend(image_func.__all__)
__all__.extend(spectral_func.__all__)
__all__.extend(vmap_func.__all__)
__all__.extend(sparse_unary_func.__all__)
__all__.extend(clip_func.__all__)
__all__.sort()


@@ -0,0 +1,128 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Defines clip operators with functional form."""
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.common.tensor import Tensor
__all__ = [
'clip_by_value',
]
hyper_map = C.HyperMap()
max_op = _get_cache_prim(P.Maximum)()
min_op = _get_cache_prim(P.Minimum)()
cast_op = _get_cache_prim(P.Cast)()
scalar2tensor_op = _get_cache_prim(P.ScalarToTensor)()
partial_op = _get_cache_prim(P.Partial)()
def clip_by_value(x, clip_value_min=None, clip_value_max=None):
r"""
Clips tensor values to a specified min and max.
Limits the value of :math:`x` to a range, whose lower limit is `clip_value_min`
and upper limit is `clip_value_max` .
.. math::
out_i= \left\{
\begin{array}{align}
clip\_value\_max & \text{ if } x_i\ge clip\_value\_max \\
x_i & \text{ if } clip\_value\_min \lt x_i \lt clip\_value\_max \\
clip\_value\_min & \text{ if } x_i \le clip\_value\_min \\
\end{array}\right.
Note:
The data type of `x`, `clip_value_min` and `clip_value_max` should support implicit type conversion and cannot
be bool type.
Args:
x (Union(Tensor, list[Tensor], tuple[Tensor])): Input data, a Tensor or a list or tuple of Tensors.
The shape of each Tensor is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
clip_value_min (Union(Tensor, float, int)): The minimum value. `clip_value_min` and `clip_value_max`
cannot be all None. Default: None.
clip_value_max (Union(Tensor, float, int)): The maximum value. `clip_value_min` and `clip_value_max`
cannot be all None. Default: None.
Returns:
(Union(Tensor, tuple[Tensor], list[Tensor])), a clipped Tensor, or a tuple or list of clipped Tensors.
The data type and shape are the same as `x`.
Raises:
ValueError: If both `clip_value_min` and `clip_value_max` are None.
TypeError: If the type of `x` is not Tensor, list[Tensor] or tuple[Tensor].
TypeError: If the type of `clip_value_min` is not None, Tensor, float or int.
TypeError: If the type of `clip_value_max` is not None, Tensor, float or int.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> # case 1: the data type of x is Tensor
>>> from mindspore import Tensor, ops
>>> import numpy as np
>>> min_value = Tensor(5, mindspore.float32)
>>> max_value = Tensor(20, mindspore.float32)
>>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
>>> output = ops.clip_by_value(x, min_value, max_value)
>>> print(output)
[[ 5. 20. 5. 7.]
[ 5. 11. 6. 20.]]
>>> # case 2: the data type of x is list[Tensor]
>>> min_value = 5
>>> max_value = 20
>>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
>>> y = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
>>> output = ops.clip_by_value([x,y], min_value, max_value)
>>> print(output)
[[[ 5. 20. 5. 7.]
[ 5. 11. 6. 20.]],
[[ 5. 20. 5. 7.]
[ 5. 11. 6. 20.]]]
"""
def _clip_by_value(clip_min, clip_max, x):
if not isinstance(x, Tensor):
TypeError("Then type of 'x' must be Tensor")
result = x
if clip_min is not None:
result = max_op(result, cast_op(clip_min, x.dtype))
if clip_max is not None:
result = min_op(result, cast_op(clip_max, x.dtype))
return result
if clip_value_min is None and clip_value_max is None:
raise ValueError("At least one of 'clip_value_min' or 'clip_value_max' must not be None")
if not isinstance(x, (Tensor, tuple, list)):
raise TypeError("The input of 'clip_by_value' must be Tensor, tuple[Tensor] or list[Tensor]")
if not isinstance(clip_value_min, (type(None), Tensor, float, int)):
raise TypeError("The type of 'clip_value_min' must be one of None, Tensor, float or int.")
if not isinstance(clip_value_max, (type(None), Tensor, float, int)):
raise TypeError("The type of 'clip_value_max' must be one of None, Tensor, float or int.")
if isinstance(clip_value_min, (float, int)):
clip_value_min = scalar2tensor_op(clip_value_min)
if isinstance(clip_value_max, (float, int)):
clip_value_max = scalar2tensor_op(clip_value_max)
if isinstance(x, Tensor):
return _clip_by_value(clip_value_min, clip_value_max, x)
results = hyper_map(partial_op(_clip_by_value, clip_value_min, clip_value_max), x)
if isinstance(x, tuple):
results = tuple(results)
return results
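The `hyper_map(partial_op(_clip_by_value, ...))` call is what generalizes the scalar path to containers: `HyperMap` applies the bound leaf function to each Tensor, and the container type is restored afterwards. The same dispatch in plain Python terms (a conceptual sketch, not MindSpore internals):

def clip_container(xs, lo, hi):
    # Mirror of the hyper_map dispatch above: apply the leaf clip to
    # every element, then restore the container type.
    results = [min(max(v, lo), hi) for v in xs]
    return tuple(results) if isinstance(xs, tuple) else results

print(clip_container((-5, 3, 42), 0, 10))   # (0, 3, 10)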


@@ -123,7 +123,7 @@ def _clip_grad(clip_type, clip_value, grad):
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(
new_grad = ops.clip_by_value(
grad, F.cast(F.tuple_to_array((-clip_value,)),
dt), F.cast(F.tuple_to_array((clip_value,)), dt)
)


@@ -14,6 +14,7 @@
# ============================================================================
"""Bert for pretraining."""
import numpy as np
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import context
from mindspore.common import dtype as mstype
@@ -51,8 +52,8 @@ def _clip_grad(clip_type, clip_value, grad):
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad


@@ -18,12 +18,12 @@ import copy
import math
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
@@ -274,9 +274,9 @@ class RelaPosMatrixGenerator(nn.Cell):
transpose_out = self.range_mat(tile_col_out, (length, length))
distance_mat = self.sub(range_mat_out, transpose_out)
distance_mat_clipped = C.clip_by_value(distance_mat,
self._min_relative_position,
self._max_relative_position)
distance_mat_clipped = ops.clip_by_value(distance_mat,
self._min_relative_position,
self._max_relative_position)
# Shift values to be >=0. Each integer still uniquely identifies a
# relative position difference.
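A small NumPy sketch of what `RelaPosMatrixGenerator` computes here: pairwise offsets are clipped to the allowed window, then shifted so every bucket id is non-negative (window size 2 is an arbitrary illustration, not a value from this diff):

import numpy as np

length, k = 4, 2                      # k: max relative position
r = np.arange(length)
distance = r[None, :] - r[:, None]    # pairwise relative offsets
clipped = np.clip(distance, -k, k)    # bound offsets to [-k, k]
buckets = clipped + k                 # shift so ids start at 0
print(buckets)                        # ids in [0, 2k]: 2k + 1 buckets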


@@ -15,6 +15,7 @@
"""Bert for pretraining."""
import numpy as np
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.common.initializer import initializer, TruncatedNormal
from mindspore.ops import operations as P
@@ -53,8 +54,8 @@ def _clip_grad(clip_type, clip_value, grad):
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad


@@ -18,11 +18,11 @@ import math
import copy
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
@@ -291,9 +291,9 @@ class RelaPosMatrixGenerator(nn.Cell):
transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
distance_mat = self.sub(range_mat_out, transpose_out)
distance_mat_clipped = C.clip_by_value(distance_mat,
self._min_relative_position,
self._max_relative_position)
distance_mat_clipped = ops.clip_by_value(distance_mat,
self._min_relative_position,
self._max_relative_position)
# Shift values to be >=0. Each integer still uniquely identifies a
# relative position difference.


@@ -0,0 +1,75 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, ops
class NetWorkClipByValue(nn.Cell):
def construct(self, x, min_value, max_value):
return ops.clip_by_value(x, min_value, max_value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_ops_clip_by_value_tensor(mode):
"""
Feature: ops.clip_by_value
Description: Verify the result of clip_by_value
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor(np.array([-0.5962, 0.4985, 0.2349, -0.4396, 0.4525]), ms.float32)
net = NetWorkClipByValue()
output = net(x, -0.3, 0.4)
expect_output = [-0.3, 0.4, 0.2349, -0.3, 0.4]
assert np.allclose(output.asnumpy(), expect_output)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_clip_by_value_list_tensor(mode):
"""
Feature: ops.clip_by_value
Description: Verify the result of clip_by_value
Expectation: success
"""
ms.set_context(mode=mode)
x1 = Tensor(np.array([-0.5962, 0.4985, 0.2349, -0.4396, 0.4525]), ms.float32)
x2 = Tensor(np.array([0.6035, 0.6959, 0.0150, -0.5766, 0.5432]), ms.float32)
x3 = Tensor(np.array([0.7549, 0.1056, 0.3312, -0.4060, 0.9821]), ms.float32)
net = NetWorkClipByValue()
output = net([x1, x2, x3], -0.3, 0.4)
expect_output = [[-0.3, 0.4, 0.2349, -0.3, 0.4],
[0.4, 0.4, 0.0150, -0.3, 0.4],
[0.4, 0.1056, 0.3312, -0.3, 0.4]
]
assert np.allclose(output[0].asnumpy(), expect_output[0])
assert np.allclose(output[1].asnumpy(), expect_output[1])
assert np.allclose(output[2].asnumpy(), expect_output[2])


@@ -16,6 +16,7 @@ import numpy as np
import pytest
import mindspore.context as context
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
@@ -23,7 +24,6 @@ from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.nn import layer
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops.operations import _inner_ops as inner
num_one = Tensor(np.ones([1]), mstype.float32)
@@ -88,7 +88,7 @@ class LambGPUOrigin(nn.Cell):
self.op_select(self.op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
ones)
tens = self.op_fill(self.op_dtype(trust_ratio), self.op_shape(trust_ratio), 10.0)
trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
trust_ratio = ops.clip_by_value(trust_ratio, zeros, tens)
update = next_mm / (self.op_sqrt(next_vv) + eps)
if decay_flag:


@@ -0,0 +1,127 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_clip_func """
import functools
import numpy as np
import mindspore.nn as nn
from mindspore import ops
import mindspore.context as context
from mindspore import Tensor
from tests.mindspore_test_framework.mindspore_test import mindspore_test
from tests.mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from tests.mindspore_test_framework.pipeline.forward.verify_exception \
import pipeline_for_verify_exception_for_case_by_case_config
context.set_context(mode=(context.GRAPH_MODE))
class NetWorkClipByValue(nn.Cell):
__doc__ = ' NetWorkClipByValue definition '
def __init__(self):
super(NetWorkClipByValue, self).__init__()
self.clip_func = ops.clip_by_value
def construct(self, x, min_value, max_value):
return self.clip_func(x, min_value, max_value)
class NetWorkClipListByValue(nn.Cell):
__doc__ = ' NetWorkClipListByValue definition '
def __init__(self):
super(NetWorkClipListByValue, self).__init__()
self.clip_func = ops.clip_by_value
def construct(self, x1, x2, x3, min_value, max_value):
return self.clip_func([x1, x2, x3], min_value, max_value)
test_case_clip_func = [
('ClipbyValue_1', {
'block': NetWorkClipByValue(),
'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
2,
8.0],
'skip': ['backward']}),
('ClipbyValue_2', {
'block': NetWorkClipByValue(),
'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
None,
8],
'skip': ['backward']}),
('ClipbyValue_3', {
'block': NetWorkClipByValue(),
'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
Tensor(np.array([2]).astype(np.float32)),
Tensor(np.array(8).astype(np.float32))],
'skip': ['backward']}),
('ClipListbyValue_1', {
'block': NetWorkClipListByValue(),
'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
Tensor(np.array(2).astype(np.float32)),
8],
'skip': ['backward']}),
]
test_cases_for_verify_exception = [
('ClipByValueERR_1', {
'block': (NetWorkClipByValue(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
Tensor(np.array([2, 3]).astype(np.float32)),
8]}),
('ClipByValueERR_2', {
'block': (NetWorkClipByValue(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
'2',
8]}),
('ClipByValueERR_3', {
'block': (NetWorkClipByValue(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
2,
'8']}),
('ClipByValueERR_4', {
'block': (NetWorkClipByValue(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.ones([1, 3, 3]).astype(np.float32)),
None,
None]}),
]
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
"""
Feature: list of test cases for the clip functions
Description: compile forward test cases for ops.clip_by_value
Expectation: success
"""
context.set_context(mode=(context.GRAPH_MODE))
return functools.reduce(lambda x, y: x + y, [test_case_clip_func])
@mindspore_test(pipeline_for_verify_exception_for_case_by_case_config)
def test_check_exception():
"""
Feature: test clip_by_value exceptions
Description: test clip_by_value exception cases
Expectation: throw errors
"""
return test_cases_for_verify_exception


@@ -24,6 +24,7 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.context import ParallelMode
@@ -48,8 +49,8 @@ update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2,
def _clip_grad(clip_type, clip_value, grad):
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad