!44835 [Task I55EAD] Optimizer provides gradient clipping
Merge pull request !44835 from DavidFFFan/grad_clip
This commit is contained in:
commit 9414c6c558
@@ -17,13 +17,19 @@
.. note::
    - `clip_value_min` must be less than or equal to `clip_value_max`;
-   - The data types of :math:`x`, `clip_value_min` and `clip_value_max` must support implicit type conversion and cannot all be bool at the same time.
+   - The data types of :math:`x`, `clip_value_min` and `clip_value_max` must support implicit type conversion and cannot be bool.

Args:
-   - **x** (Tensor) - Input of `clip_by_value`, a Tensor of any dimension.
-   - **clip_value_min** (Tensor) - The specified minimum value.
-   - **clip_value_max** (Tensor) - The specified maximum value.
+   - **x** (Union(Tensor, list[Tensor], tuple[Tensor])) - Input of `clip_by_value`: a Tensor, or a list or tuple of Tensors. Tensors of any dimension are supported.
+   - **clip_value_min** (Union(Tensor, float, int)) - The specified minimum value.
+   - **clip_value_max** (Union(Tensor, float, int)) - The specified maximum value.

Returns:
-   Tensor, the clipped Tensor, with the same shape and data type as `x`.
+   Tensor, or a list or tuple of Tensors, the clipped result, with the same shape and data type as `x`.

+Raises:
+   - **ValueError** - If both `clip_value_min` and `clip_value_max` are None.
+   - **TypeError** - If the type of `x` is not Tensor, list[Tensor] or tuple[Tensor].
+   - **TypeError** - If the type of `clip_value_min` is not None, Tensor, float or int.
+   - **TypeError** - If the type of `clip_value_max` is not None, Tensor, float or int.
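The widened signature documented above is easiest to see in use. A minimal sketch (values illustrative; assumes a MindSpore build that contains this PR):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.array([[1., 25., 5., 7.]]), ms.float32)
y = Tensor(np.array([[4., 11., 6., 21.]]), ms.float32)

# Scalar bounds are now accepted and broadcast against each input.
clipped_single = ops.clip_by_value(x, 5, 20.0)
# A list (or tuple) of Tensors is clipped element-wise; the container kind is preserved.
clipped_pair = ops.clip_by_value([x, y], 5, 20.0)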
@@ -19,6 +19,7 @@ import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
+import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore.ops import operations as P
@@ -308,15 +309,15 @@ class RandomColorAdjust(nn.Cell):
        # Apply brightness
        x = self.mul(x, br_rand_factor)
-       x = C.clip_by_value(x, 0.0, 255.0)
+       x = ops.clip_by_value(x, 0.0, 255.0)

        # Apply contrast
        x = self.mul(x, cont_rand_factor) + self.mul((1 - cont_rand_factor), x_gray_mean)
-       x = C.clip_by_value(x, 0.0, 255.0)
+       x = ops.clip_by_value(x, 0.0, 255.0)

        # Apply saturation
        x = self.mul(x, sat_rand_factor) + self.mul((1 - sat_rand_factor), x_gray)
-       x = C.clip_by_value(x, 0.0, 255.0)
+       x = ops.clip_by_value(x, 0.0, 255.0)

        # Apply Hue Transform
        # Convert tensor from rgb to hsv
@@ -361,7 +362,7 @@ class RandomColorAdjust(nn.Cell):
        x_rgb = x_rgb + C.repeat_elements(self.expand_dims((v - c), 1), 3, 1)

        x_rgb = self.transpose(x, (0, 2, 3, 1)) * 255.0
-       x_rgb = C.clip_by_value(x, 0.0, 255.0)
+       x_rgb = ops.clip_by_value(x, 0.0, 255.0)

        return x_rgb
@@ -414,7 +415,7 @@ class RandomSharpness(nn.Cell):
        x_sharp = self.transpose(x_sharp, (0, 2, 3, 1))

        x = self.mul(x, degree_rand_factor) + self.mul((1 - degree_rand_factor), x_sharp)
-       x = C.clip_by_value(x, 0.0, 255.0)
+       x = ops.clip_by_value(x, 0.0, 255.0)
        return x
@@ -21,6 +21,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
+import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.log as logger
@@ -87,8 +88,8 @@ def _clip_grad(clip_type, clip_value, grad):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
-       new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
-                                  F.cast(F.tuple_to_array((clip_value,)), dt))
+       new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
+                                    F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
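For readers skimming the diff: `_clip_grad` (here and in the Bert/loss-scale copies further down) chooses between element-wise clipping and norm clipping. A rough NumPy sketch of the two behaviors (names are hypothetical, not MindSpore API):

import numpy as np

def clip_grad_sketch(clip_type: int, clip_value: float, grad: np.ndarray) -> np.ndarray:
    """Mimics _clip_grad: 0 -> bound each element, otherwise -> rescale by L2 norm."""
    if clip_type == 0:
        return np.clip(grad, -clip_value, clip_value)   # element-wise bound
    norm = np.linalg.norm(grad)
    return grad if norm <= clip_value else grad * (clip_value / norm)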
@@ -19,9 +19,9 @@ from mindspore._checkparam import Validator as validator
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
import mindspore.ops as ops
import mindspore.nn as nn
@@ -214,7 +214,7 @@ def clamp_probs(probs):
    clamp probs boundary
    """
    eps = P.Eps()(probs)
-   return C.clip_by_value(probs, eps, 1-eps)
+   return ops.clip_by_value(probs, eps, 1-eps)


def probs_to_logits(probs, is_binary=False):
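`clamp_probs` keeps probabilities strictly inside (0, 1) so later `log`/logit calls stay finite. A hedged NumPy illustration, using float32 machine epsilon in place of `P.Eps`:

import numpy as np

def clamp_probs_sketch(probs: np.ndarray) -> np.ndarray:
    # float32 machine epsilon plays the role of P.Eps()(probs) here.
    eps = np.finfo(np.float32).eps
    return np.clip(probs, eps, 1 - eps)

p = np.array([0.0, 0.5, 1.0], dtype=np.float32)
print(np.log(clamp_probs_sketch(p)))  # finite values, no -inf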
@@ -20,6 +20,7 @@ from mindspore.ops import composite as C
from mindspore.ops.functional import stop_gradient
from mindspore.ops.operations import _inner_ops as inner
from mindspore._checkparam import Validator
+import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from .distribution import Distribution
@@ -141,7 +142,7 @@ class Categorical(Distribution):
        self.argmax = P.ArgMaxWithValue(axis=-1)
        self.broadcast = broadcast_to
        self.cast = P.Cast()
-       self.clip_by_value = C.clip_by_value
+       self.clip_by_value = ops.clip_by_value
        self.concat = P.Concat(-1)
        self.cumsum = P.CumSum()
        self.dtypeop = P.DType()
@@ -23,7 +23,7 @@ from __future__ import absolute_import
from mindspore.ops.composite.base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \
    tail, zip_operation, _Vmap, _TaylorOperation
from mindspore.ops.composite.env_ops import env_get
-from mindspore.ops.composite.clip_ops import clip_by_value, clip_by_global_norm
+from mindspore.ops.composite.clip_ops import clip_by_global_norm
from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
@@ -31,6 +31,7 @@ from mindspore.ops.composite.random_ops import normal, laplace, uniform, gamma,
from mindspore.ops.composite.math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul, cummin
from mindspore.ops.composite.array_ops import repeat_interleave, repeat_elements, sequence_mask
from mindspore.ops.composite.vmap_ops import _VmapGeneralPreprocess, _VmapGeneralRule
+from mindspore.ops.function.clip_func import clip_by_value


__all__ = [
@@ -51,7 +52,6 @@ __all__ = [
    'gamma',
    'poisson',
    'multinomial',
-   'clip_by_value',
    'clip_by_global_norm',
    'count_nonzero',
    'cummin',
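Since `clip_by_value` is re-imported above from `mindspore.ops.function.clip_func`, the composite module still exposes it as an attribute even though the hunk above drops it from composite's `__all__`. A quick sanity check one might run (an assumption about this PR's intent, not a documented guarantee):

from mindspore.ops import composite as C
import mindspore.ops as ops

# Both names should now resolve to the same functional implementation,
# so legacy C.clip_by_value call sites keep working during the migration.
assert C.clip_by_value is ops.clip_by_value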
@@ -15,8 +15,6 @@

"""Operations for clipping tensors to min/max values."""
from __future__ import absolute_import

import numpy as np
from mindspore.nn.cell import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
@@ -28,108 +26,6 @@ from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr

-@constexpr
-def _check_output_shape(input_shape, out_shape, prim_name=None):
-    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
-    if input_shape != out_shape:
-        raise ValueError(f"{msg_prefix} input 'x' shape must be equal to the output shape, but got "
-                         f"input 'x' shape {input_shape}, output shape {out_shape}.")
-
-
-def check_np_type(np_dtype, is_max_val):
-    if not (np.issubsctype(np_dtype, np.floating) or np.issubsctype(np_dtype, np.integer) or
-            np.issubsctype(np_dtype, np.complex64) or np.issubsctype(np_dtype, np.complex128) or
-            np.issubsctype(np_dtype, np.bool_)):
-        value_info = ("clip_value_max", "clip_value_min") if is_max_val else ("clip_value_min", "clip_value_max")
-        raise ValueError(f"When {value_info[0]} is none, The date type of {value_info[1]} only support integer,"
-                         f"floating, bool, complex64 or complex128. But got {np_dtype}")
-
-
-@constexpr
-def create_max_min_value(ms_type, is_max_val):
-    """create max or min value"""
-    np_dtype = mstype.dtype_to_nptype(ms_type)
-    check_np_type(np_dtype, is_max_val)
-    if np.issubsctype(np_dtype, np.floating):
-        val = np.finfo(np_dtype).max if is_max_val else np.finfo(np_dtype).min
-    elif np.issubsctype(np_dtype, np.integer):
-        val = np.iinfo(np_dtype).max if is_max_val else np.iinfo(np_dtype).min
-    elif np.issubsctype(np_dtype, np.complex64):
-        val = np.finfo(np.float32).max if is_max_val else np.finfo(np.float32).min
-        val = np.complex64(np.complex(val, val))
-    elif np.issubsctype(np_dtype, np.complex128):
-        val = np.finfo(np.float64).max if is_max_val else np.finfo(np.float64).min
-        val = np.complex128(np.complex(val, val))
-    else:
-        val = np.bool_(True) if is_max_val else np.bool_(False)
-    return Tensor(val, ms_type)
-
-
-@constexpr
-def raise_value_error():
-    raise ValueError("At least one of 'clip_value_min' or 'clip_value_max' must not be None")
-
-
-def clip_by_value(x, clip_value_min=None, clip_value_max=None):
-    r"""
-    Clips tensor values to a specified min and max.
-
-    Limits the value of :math:`x` to a range, whose lower limit is `clip_value_min`
-    and upper limit is `clip_value_max` .
-
-    .. math::
-
-        out_i= \left\{
-        \begin{array}{align}
-            clip\_value\_max & \text{ if } x_i\ge clip\_value\_max \\
-            x_i & \text{ if } clip\_value\_min \lt x_i \lt clip\_value\_max \\
-            clip\_value\_min & \text{ if } x_i \le clip\_value\_min \\
-        \end{array}\right.
-
-    Note:
-        `clip_value_min` needs to be less than or equal to `clip_value_max` . The data type of x, `clip_value_min` and
-        `clip_value_max` should support implicit type conversion and cannot all be bool type.
-
-    Args:
-        x (Tensor): Input data. The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
-        clip_value_min (Tensor): The minimum value. `clip_value_min` and `clip_value_max` cannot be all None.
-            Default: None.
-        clip_value_max (Tensor): The maximum value. `clip_value_min` and `clip_value_max` cannot be all None.
-            Default: None.
-
-    Returns:
-        Tensor, a clipped Tensor. The data type is the one with higher precision or higher digits among
-        the x, `clip_value_min` and `clip_value_max` .
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> from mindspore import Tensor, ops
-        >>> import numpy as np
-        >>> min_value = Tensor(5, mindspore.float32)
-        >>> max_value = Tensor(20, mindspore.float32)
-        >>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
-        >>> output = ops.clip_by_value(x, min_value, max_value)
-        >>> print(output)
-        [[ 5. 20. 5. 7.]
-         [ 5. 11. 6. 20.]]
-    """
-
-    min_op = P.Minimum()
-    max_op = P.Maximum()
-    if clip_value_min is None and clip_value_max is None:
-        raise_value_error()
-    if clip_value_min is None:
-        clip_value_min = create_max_min_value(F.dtype(clip_value_max), False)
-    if clip_value_max is None:
-        clip_value_max = create_max_min_value(F.dtype(clip_value_min), True)
-    x_min = min_op(x, clip_value_max)
-    x_max = max_op(x_min, clip_value_min)
-    _check_output_shape(F.shape(x), F.shape(x_max), 'clip_by_value')
-    return x_max


# The attribute grad_scale is needed for enabling the parallel mode
# If this is removed, C.clip_by_global_norm will have precision error in semi/auto parallel mode.
expand_dims = P.ExpandDims().add_prim_attr("grad_scale", True)
@@ -25,6 +25,7 @@ from . import (
    math_func,
    nn_func,
    linalg_func,
+   clip_func,
)
from .array_func import (
    unique,
@@ -533,6 +534,9 @@ from .sparse_unary_func import (
    coo_sigmoid,
    coo_sin
)
+from .clip_func import (
+   clip_by_value,
+)

__all__ = []
__all__.extend(array_func.__all__)
@@ -548,4 +552,5 @@ __all__.extend(image_func.__all__)
__all__.extend(spectral_func.__all__)
__all__.extend(vmap_func.__all__)
__all__.extend(sparse_unary_func.__all__)
+__all__.extend(clip_func.__all__)
__all__.sort()
@@ -0,0 +1,128 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Defines clip operators with functional form."""

from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.common.tensor import Tensor

__all__ = [
    'clip_by_value',
]

hyper_map = C.HyperMap()
max_op = _get_cache_prim(P.Maximum)()
min_op = _get_cache_prim(P.Minimum)()
cast_op = _get_cache_prim(P.Cast)()
scalar2tensor_op = _get_cache_prim(P.ScalarToTensor)()
partial_op = _get_cache_prim(P.Partial)()


def clip_by_value(x, clip_value_min=None, clip_value_max=None):
    r"""
    Clips tensor values to a specified min and max.

    Limits the value of :math:`x` to a range, whose lower limit is `clip_value_min`
    and upper limit is `clip_value_max` .

    .. math::

        out_i= \left\{
        \begin{array}{align}
            clip\_value\_max & \text{ if } x_i\ge clip\_value\_max \\
            x_i & \text{ if } clip\_value\_min \lt x_i \lt clip\_value\_max \\
            clip\_value\_min & \text{ if } x_i \le clip\_value\_min \\
        \end{array}\right.

    Note:
        The data types of `x`, `clip_value_min` and `clip_value_max` should support implicit type conversion and
        cannot be bool type.

    Args:
        x (Union(Tensor, list[Tensor], tuple[Tensor])): Input data, a Tensor or a list or tuple of Tensors.
            The shape of each Tensor is :math:`(N,*)` where :math:`*` means
            any number of additional dimensions.
        clip_value_min (Union(Tensor, float, int)): The minimum value. `clip_value_min` and `clip_value_max`
            cannot both be None. Default: None.
        clip_value_max (Union(Tensor, float, int)): The maximum value. `clip_value_min` and `clip_value_max`
            cannot both be None. Default: None.

    Returns:
        (Union(Tensor, tuple[Tensor], list[Tensor])), a clipped Tensor or a tuple or list of clipped Tensors.
        The data type and shape are the same as `x`.

    Raises:
        ValueError: If both `clip_value_min` and `clip_value_max` are None.
        TypeError: If the type of `x` is not Tensor, list[Tensor] or tuple[Tensor].
        TypeError: If the type of `clip_value_min` is not None, Tensor, float or int.
        TypeError: If the type of `clip_value_max` is not None, Tensor, float or int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # case 1: the data type of x is Tensor
        >>> import mindspore
        >>> from mindspore import Tensor, ops
        >>> import numpy as np
        >>> min_value = Tensor(5, mindspore.float32)
        >>> max_value = Tensor(20, mindspore.float32)
        >>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
        >>> output = ops.clip_by_value(x, min_value, max_value)
        >>> print(output)
        [[ 5. 20.  5.  7.]
         [ 5. 11.  6. 20.]]
        >>> # case 2: the data type of x is list[Tensor]
        >>> min_value = 5
        >>> max_value = 20
        >>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
        >>> y = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
        >>> output = ops.clip_by_value([x, y], min_value, max_value)
        >>> print(output)
        [[[ 5. 20.  5.  7.]
          [ 5. 11.  6. 20.]],
         [[ 5. 20.  5.  7.]
          [ 5. 11.  6. 20.]]]
    """
    def _clip_by_value(clip_min, clip_max, x):
        if not isinstance(x, Tensor):
            raise TypeError("The type of 'x' must be Tensor")
        result = x
        if clip_min is not None:
            result = max_op(result, cast_op(clip_min, x.dtype))
        if clip_max is not None:
            result = min_op(result, cast_op(clip_max, x.dtype))
        return result

    if clip_value_min is None and clip_value_max is None:
        raise ValueError("At least one of 'clip_value_min' or 'clip_value_max' must not be None")
    if not isinstance(x, (Tensor, tuple, list)):
        raise TypeError("The input of 'clip_by_value' must be Tensor, tuple[Tensor] or list[Tensor]")
    if not isinstance(clip_value_min, (type(None), Tensor, float, int)):
        raise TypeError("The type of 'clip_value_min' must be one of None, Tensor, float or int.")
    if not isinstance(clip_value_max, (type(None), Tensor, float, int)):
        raise TypeError("The type of 'clip_value_max' must be one of None, Tensor, float or int.")
    if isinstance(clip_value_min, (float, int)):
        clip_value_min = scalar2tensor_op(clip_value_min)
    if isinstance(clip_value_max, (float, int)):
        clip_value_max = scalar2tensor_op(clip_value_max)

    if isinstance(x, Tensor):
        return _clip_by_value(clip_value_min, clip_value_max, x)
    results = hyper_map(partial_op(_clip_by_value, clip_value_min, clip_value_max), x)
    if isinstance(x, tuple):
        results = tuple(results)
    return results
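A design note on the new implementation: the list/tuple path reuses the single-Tensor path through `HyperMap` plus `Partial`, so every element is clipped with the same bound pair and the container kind is preserved. A rough pure-Python analogue of that dispatch (illustrative only, not graph-mode code):

from functools import partial

def _clip_one_sketch(clip_min, clip_max, x):
    # Stands in for the max_op/min_op pair applied to one element.
    if clip_min is not None:
        x = max(x, clip_min)
    if clip_max is not None:
        x = min(x, clip_max)
    return x

def clip_container_sketch(xs, clip_min, clip_max):
    # HyperMap(Partial(fn, min, max), xs) behaves like map() with a
    # partially applied per-element function.
    clip_one = partial(_clip_one_sketch, clip_min, clip_max)
    results = list(map(clip_one, xs))
    return tuple(results) if isinstance(xs, tuple) else results

print(clip_container_sketch([1, 25, 7], 5, 20))  # [5, 20, 7]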
@@ -123,7 +123,7 @@ def _clip_grad(clip_type, clip_value, grad):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
-       new_grad = C.clip_by_value(
+       new_grad = ops.clip_by_value(
            grad, F.cast(F.tuple_to_array((-clip_value,)),
                         dt), F.cast(F.tuple_to_array((clip_value,)), dt)
        )
@@ -14,6 +14,7 @@
# ============================================================================
"""Bert for pretraining."""
import numpy as np
+import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import context
from mindspore.common import dtype as mstype
@@ -51,8 +52,8 @@ def _clip_grad(clip_type, clip_value, grad):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
-       new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
-                                  F.cast(F.tuple_to_array((clip_value,)), dt))
+       new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
+                                    F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
@@ -18,12 +18,12 @@ import copy
import math
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
@@ -274,9 +274,9 @@ class RelaPosMatrixGenerator(nn.Cell):
        transpose_out = self.range_mat(tile_col_out, (length, length))
        distance_mat = self.sub(range_mat_out, transpose_out)

-       distance_mat_clipped = C.clip_by_value(distance_mat,
-                                              self._min_relative_position,
-                                              self._max_relative_position)
+       distance_mat_clipped = ops.clip_by_value(distance_mat,
+                                                self._min_relative_position,
+                                                self._max_relative_position)

        # Shift values to be >=0. Each integer still uniquely identifies a
        # relative position difference.
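`RelaPosMatrixGenerator` clips a signed token-distance matrix into a fixed relative-position window before shifting it to non-negative indices. A small NumPy sketch of the clipped-and-shifted matrix (window of ±2 chosen arbitrarily):

import numpy as np

length, max_rel = 5, 2
idx = np.arange(length)
distance_mat = idx[None, :] - idx[:, None]   # signed token distances
clipped = np.clip(distance_mat, -max_rel, max_rel)
shifted = clipped + max_rel                  # shift values to be >= 0
print(shifted)  # every entry now indexes a relative-position embedding table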
@@ -15,6 +15,7 @@
"""Bert for pretraining."""
import numpy as np

+import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.common.initializer import initializer, TruncatedNormal
from mindspore.ops import operations as P
@@ -53,8 +54,8 @@ def _clip_grad(clip_type, clip_value, grad):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
-       new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
-                                  F.cast(F.tuple_to_array((clip_value,)), dt))
+       new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
+                                    F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
@@ -18,11 +18,11 @@ import math
import copy
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
@@ -291,9 +291,9 @@ class RelaPosMatrixGenerator(nn.Cell):
        transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
        distance_mat = self.sub(range_mat_out, transpose_out)

-       distance_mat_clipped = C.clip_by_value(distance_mat,
-                                              self._min_relative_position,
-                                              self._max_relative_position)
+       distance_mat_clipped = ops.clip_by_value(distance_mat,
+                                                self._min_relative_position,
+                                                self._max_relative_position)

        # Shift values to be >=0. Each integer still uniquely identifies a
        # relative position difference.
@@ -0,0 +1,75 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, ops


class NetWorkClipByValue(nn.Cell):
    def construct(self, x, min_value, max_value):
        return ops.clip_by_value(x, min_value, max_value)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_ops_clip_by_value_tensor(mode):
    """
    Feature: ops.clip_by_value
    Description: Verify the result of clip_by_value
    Expectation: success
    """
    ms.set_context(mode=mode)
    x = Tensor(np.array([-0.5962, 0.4985, 0.2349, -0.4396, 0.4525]), ms.float32)
    net = NetWorkClipByValue()
    output = net(x, -0.3, 0.4)
    expect_output = [-0.3, 0.4, 0.2349, -0.3, 0.4]
    assert np.allclose(output.asnumpy(), expect_output)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_clip_by_value_list_tensor(mode):
    """
    Feature: ops.clip_by_value
    Description: Verify the result of clip_by_value with list[Tensor] input
    Expectation: success
    """
    ms.set_context(mode=mode)
    x1 = Tensor(np.array([-0.5962, 0.4985, 0.2349, -0.4396, 0.4525]), ms.float32)
    x2 = Tensor(np.array([0.6035, 0.6959, 0.0150, -0.5766, 0.5432]), ms.float32)
    x3 = Tensor(np.array([0.7549, 0.1056, 0.3312, -0.4060, 0.9821]), ms.float32)
    net = NetWorkClipByValue()
    output = net([x1, x2, x3], -0.3, 0.4)
    expect_output = [[-0.3, 0.4, 0.2349, -0.3, 0.4],
                     [0.4, 0.4, 0.0150, -0.3, 0.4],
                     [0.4, 0.1056, 0.3312, -0.3, 0.4]]
    assert np.allclose(output[0].asnumpy(), expect_output[0])
    assert np.allclose(output[1].asnumpy(), expect_output[1])
    assert np.allclose(output[2].asnumpy(), expect_output[2])
@@ -16,6 +16,7 @@ import numpy as np
import pytest

import mindspore.context as context
+import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
@@ -23,7 +24,6 @@ from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.nn import layer
from mindspore.ops import functional as F
-from mindspore.ops import composite as C
from mindspore.ops.operations import _inner_ops as inner

num_one = Tensor(np.ones([1]), mstype.float32)
@@ -88,7 +88,7 @@ class LambGPUOrigin(nn.Cell):
            self.op_select(self.op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
            ones)
        tens = self.op_fill(self.op_dtype(trust_ratio), self.op_shape(trust_ratio), 10.0)
-       trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
+       trust_ratio = ops.clip_by_value(trust_ratio, zeros, tens)
        update = next_mm / (self.op_sqrt(next_vv) + eps)

        if decay_flag:
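In this Lamb variant the trust ratio (weight norm over normalized update norm) is clipped into [0, 10] before scaling the update; the 10.0 bound comes from `op_fill` above. A one-line NumPy illustration with made-up norms:

import numpy as np

w_norm, g_norm_hat = 4.0, 0.02      # illustrative parameter/update norms
trust_ratio = np.clip(w_norm / g_norm_hat, 0.0, 10.0)
print(trust_ratio)                  # raw ratio 200.0 is capped at 10.0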
@@ -0,0 +1,127 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_clip_func """
import functools

import numpy as np
import mindspore.nn as nn
from mindspore import ops
import mindspore.context as context
from mindspore import Tensor
from tests.mindspore_test_framework.mindspore_test import mindspore_test
from tests.mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from tests.mindspore_test_framework.pipeline.forward.verify_exception \
    import pipeline_for_verify_exception_for_case_by_case_config

context.set_context(mode=context.GRAPH_MODE)


class NetWorkClipByValue(nn.Cell):
    """ NetWorkClipByValue definition """

    def __init__(self):
        super(NetWorkClipByValue, self).__init__()
        self.clip_func = ops.clip_by_value

    def construct(self, x, min_value, max_value):
        return self.clip_func(x, min_value, max_value)


class NetWorkClipListByValue(nn.Cell):
    """ NetWorkClipListByValue definition """

    def __init__(self):
        super(NetWorkClipListByValue, self).__init__()
        self.clip_func = ops.clip_by_value

    def construct(self, x1, x2, x3, min_value, max_value):
        return self.clip_func([x1, x2, x3], min_value, max_value)


test_case_clip_func = [
    ('ClipbyValue_1', {
        'block': NetWorkClipByValue(),
        'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
                        2,
                        8.0],
        'skip': ['backward']}),
    ('ClipbyValue_2', {
        'block': NetWorkClipByValue(),
        'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
                        None,
                        8],
        'skip': ['backward']}),
    ('ClipbyValue_3', {
        'block': NetWorkClipByValue(),
        'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
                        Tensor(np.array([2]).astype(np.float32)),
                        Tensor(np.array(8).astype(np.float32))],
        'skip': ['backward']}),
    ('ClipListbyValue_1', {
        'block': NetWorkClipListByValue(),
        'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
                        Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
                        Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
                        Tensor(np.array(2).astype(np.float32)),
                        8],
        'skip': ['backward']}),
]


test_cases_for_verify_exception = [
    ('ClipByValueERR_1', {
        'block': (NetWorkClipByValue(), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
                        Tensor(np.array([2, 3]).astype(np.float32)),
                        8]}),
    ('ClipByValueERR_2', {
        'block': (NetWorkClipByValue(), {'exception': TypeError}),
        'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
                        '2',
                        8]}),
    ('ClipByValueERR_3', {
        'block': (NetWorkClipByValue(), {'exception': TypeError}),
        'desc_inputs': [Tensor(np.random.randint(0, 10, [2, 3, 4]).astype(np.float32)),
                        2,
                        '8']}),
    ('ClipByValueERR_4', {
        'block': (NetWorkClipByValue(), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones([1, 3, 3]).astype(np.float32)),
                        None,
                        None]}),
]


@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    """
    Feature: test clip functions
    Description: compile-forward test cases for clip_by_value
    Expectation: success
    """
    context.set_context(mode=context.GRAPH_MODE)
    return functools.reduce(lambda x, y: x + y, [test_case_clip_func])


@mindspore_test(pipeline_for_verify_exception_for_case_by_case_config)
def test_check_exception():
    """
    Feature: test clip_by_value exceptions
    Description: verify that invalid inputs raise the expected errors
    Expectation: throw errors
    """
    return test_cases_for_verify_exception
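The exception table above exercises the new validation paths in clip_func.py. A direct, pipeline-free reproduction of ClipByValueERR_4 (assuming the raise statements as fixed above):

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.ones([1, 3, 3]).astype(np.float32))
try:
    ops.clip_by_value(x, None, None)   # both bounds None -> ValueError
except ValueError as e:
    print(e)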
@@ -24,6 +24,7 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
+import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.context import ParallelMode
@@ -48,8 +49,8 @@ update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2,
def _clip_grad(clip_type, clip_value, grad):
    dt = F.dtype(grad)
    if clip_type == 0:
-       new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
-                                  F.cast(F.tuple_to_array((clip_value,)), dt))
+       new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
+                                    F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad