forked from mindspore-Ecosystem/mindspore
add clip_by_global_norm layer.
This commit is contained in:
parent
ccbc6df79c
commit
4851a67bb5
@@ -273,6 +273,7 @@ class ClipByNorm(Cell):

        \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},

    where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.

    Args:
        axis (Union[None, int, tuple(int)]): Compute the L2-norm along the specified dimensions.
            Default: None, which means all dimensions are used in the calculation.

@@ -280,6 +281,7 @@ class ClipByNorm(Cell):
    Inputs:
        - **input** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
        - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`,
          or a tensor whose shape can be broadcast to the input shape.

    Outputs:
        Tensor, a clipped tensor with the same shape as the input, of type float32.
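A minimal NumPy sketch of the formula above (a hypothetical reference helper, not part of this diff; it assumes values are only rescaled when the L2-norm actually exceeds clip_norm):

import numpy as np

def clip_by_norm_ref(x, clip_norm, axis=None):
    # L2-norm over the requested axes (all axes when axis is None).
    l2 = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True))
    # Rescale only where the norm exceeds clip_norm; otherwise keep x unchanged.
    return np.where(l2 > clip_norm, clip_norm * x / l2, x)

x = np.array([[2., 3.], [1., 2.]], dtype=np.float32)
print(clip_by_norm_ref(x, 1.0))  # L2(x) = sqrt(18) ~ 4.243, so x is scaled by 1 / 4.243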
@@ -22,7 +22,7 @@ Pre-defined combination of operators.

from .base import GradOperation, HyperMap, Map, MultitypeFuncGraph, add_flags, \
    core, env_get, tail, zip_operation
from .clip_ops import clip_by_value
from .clip_ops import clip_by_value, clip_by_global_norm
from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like

@@ -49,4 +49,5 @@ __all__ = [
    'poisson',
    'multinomial',
    'clip_by_value',
    'clip_by_global_norm',
    'count_nonzero']
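With the new export in place, callers reach the op through the composite module; this is the call pattern used later in this diff for the BERT training cell (the gradient tuple below is only illustrative):

import numpy as np
from mindspore import Tensor
from mindspore.ops import composite as C

# A stand-in gradient tuple; in practice this comes from GradOperation.
grads = (Tensor(np.ones((2, 2), np.float32)), Tensor(np.ones((3,), np.float32)))
clipped = C.clip_by_global_norm(grads, 1.0, None)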
@@ -14,8 +14,15 @@
# ============================================================================

"""Operations for clipping tensors to min/max values."""

from .. import operations as P
from mindspore.nn.cell import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr


def clip_by_value(x, clip_value_min, clip_value_max):

@@ -41,3 +48,89 @@ def clip_by_value(x, clip_value_min, clip_value_max):
    x_min = min_op(x, clip_value_max)
    x_max = max_op(x_min, clip_value_min)
    return x_max


get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@get_square_sum.register("Tensor")
def _get_square_sum(x):
    norm = P.ReduceSum(False)(F.square(x), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, x):
    x = x * clip_norm / global_norm
    return x


class _ClipByGlobalNorm(Cell):
    r"""
    Clips tensor values by the ratio of the sum of their norms.

    Args:
        clip_norm (Union(float, int)): The clipping ratio. Default: 1.0.
        use_norm (Union(float, None)): The global norm. Default: None.

    Inputs:
        - **x** (Union(tuple[Tensor], list[Tensor])) - Input data to clip.

    Outputs:
        tuple[Tensor], the clipped tensors.
    """

    def __init__(self, clip_norm=1.0, use_norm=None):
        super(_ClipByGlobalNorm, self).__init__()
        # Added for interface compatibility; this parameter is not used at present.
        if use_norm is not None:
            validator.check_number("use_norm", use_norm, 0.0, Rel.GE, self.cls_name)
        validator.check_number("clip_norm", clip_norm, 0.0, Rel.GT, self.cls_name)
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()
        self.greater_equal = P.GreaterEqual()

    def construct(self, x):
        square_sum = self.hyper_map(get_square_sum, x)
        global_norm = F.sqrt(F.addn(square_sum))
        cond = self.greater_equal(global_norm, self.clip_norm)
        global_norm = F.select(cond, global_norm, self.clip_norm)
        clip_x = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), x)
        return clip_x


@constexpr
def _check_value(clip_norm):
    validator.check_number("clip_norm", clip_norm, 0.0, Rel.GT, "clip_by_global_norm")
    return clip_norm


def clip_by_global_norm(x, clip_norm=1.0, use_norm=None):
    r"""
    Clips tensor values by the ratio of the sum of their norms.

    Note:
        The input 'x' should be a tuple or list of tensors. Otherwise, an error will be raised.

    Args:
        x (Union(tuple[Tensor], list[Tensor])): Input data to clip.
        clip_norm (Union(float, int)): The clipping ratio. Default: 1.0.
        use_norm (None): The global norm. Default: None. Currently only None is supported.

    Returns:
        tuple[Tensor], the clipped tensors.

    Examples:
        >>> x1 = np.array([[2., 3.], [1., 2.]]).astype(np.float32)
        >>> x2 = np.array([[1., 4.], [3., 1.]]).astype(np.float32)
        >>> input_x = (Tensor(x1), Tensor(x2))
        >>> out = clip_by_global_norm(input_x, 1.0)
        ([[ 2.98142403e-01, 4.47213590e-01],
          [ 1.49071202e-01, 2.98142403e-01]],
         [[ 1.49071202e-01, 5.96284807e-01],
          [ 4.47213590e-01, 1.49071202e-01]])
    """
    clip_norm = _check_value(clip_norm)
    out = _ClipByGlobalNorm(clip_norm, use_norm)(x)
    return out
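As a sanity check on the docstring example, here is a hedged NumPy sketch of what construct computes: the global norm is the square root of the sum of squares over all tensors in the tuple, and each tensor is scaled by clip_norm / global_norm once the global norm reaches clip_norm (the helper name is illustrative only):

import numpy as np

def clip_by_global_norm_ref(tensors, clip_norm=1.0):
    # Global norm over the whole tuple: sqrt of the sum of all squared elements.
    global_norm = np.sqrt(sum(np.sum(np.square(t)) for t in tensors))
    # Mirrors F.select(cond, global_norm, clip_norm): no rescaling below clip_norm.
    scale = clip_norm / max(global_norm, clip_norm)
    return tuple(t * scale for t in tensors)

x1 = np.array([[2., 3.], [1., 2.]], dtype=np.float32)
x2 = np.array([[1., 4.], [3., 1.]], dtype=np.float32)
out = clip_by_global_norm_ref((x1, x2), 1.0)
# global_norm = sqrt(4+9+1+4 + 1+16+9+1) = sqrt(45) ~ 6.708, so every element is
# multiplied by 1 / 6.708 ~ 0.149071, matching the example output above.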
@@ -28,7 +28,6 @@ from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from .bert_model import BertModel
from .utils import ClipByGlobalNorm

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

@@ -565,7 +564,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
        scaling = scaling_sens * self.degree * self.accumulation_steps
        grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
        if self.enable_global_norm:
            grads = ClipByGlobalNorm()(grads)
            grads = C.clip_by_global_norm(grads, 1.0, None)
        else:
            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        accu_overflow = self.overflow_reducer(accu_overflow)
@@ -24,62 +24,12 @@ import numpy as np
import mindspore.nn as nn
from mindspore import log as logger
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR


get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@get_square_sum.register("Tensor")
def _get_square_sum(grad):
    norm = P.ReduceSum(False)(F.square(grad), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    grad = grad * clip_norm / global_norm
    return grad


class GlobalNorm(nn.Cell):
    """
    Calculate the global norm value of given tensors
    """
    def __init__(self):
        super(GlobalNorm, self).__init__()
        self.norm = nn.Norm()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads)
        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
        return global_norms


class ClipByGlobalNorm(nn.Cell):
    """
    Clip grads by global norm
    """
    def __init__(self, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
        self.global_norm = GlobalNorm()
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        global_norm = self.global_norm(grads)
        cond = P.GreaterEqual()(global_norm, self.clip_norm)
        global_norm = F.select(cond, global_norm, self.clip_norm)
        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads


class CrossEntropyCalculation(nn.Cell):
    """
    Cross Entropy loss
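Note that the removed GlobalNorm above divides the summed squares by the number of tensors before taking the square root, while the new C.clip_by_global_norm (see _ClipByGlobalNorm.construct earlier in this diff) does not. A short NumPy sketch of the two norm definitions, with illustrative helper names:

import numpy as np

def old_global_norm(grads):
    # Removed utils.GlobalNorm: sqrt of the *mean* of the per-tensor squared sums.
    square_sums = [np.sum(np.square(g)) for g in grads]
    return np.sqrt(sum(square_sums) / len(square_sums))

def new_global_norm(grads):
    # New composite clip_by_global_norm: sqrt of the plain sum of squared sums.
    return np.sqrt(sum(np.sum(np.square(g)) for g in grads))

grads = (np.ones((2, 2), np.float32), np.ones((3,), np.float32))
print(old_global_norm(grads), new_global_norm(grads))  # sqrt(7 / 2) vs sqrt(7)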
@@ -240,6 +240,20 @@ class ClipByNorm(nn.Cell):
        return norm


class ClipByGlobalNorm(nn.Cell):
    """ClipByGlobalNorm net definition"""

    def __init__(self, x, clip_norm=1.0, use_norm=None):
        super(ClipByGlobalNorm, self).__init__()
        self.x = x
        self.clip_norm = clip_norm
        self.use_norm = use_norm

    def construct(self):
        norm = C.clip_by_global_norm(self.x, self.clip_norm, self.use_norm)
        return norm


class Embedding(nn.Cell):
    """Embedding net definition"""
@@ -1130,6 +1144,11 @@ test_case_math_ops = [
        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
                        Tensor(np.array([0.01]).astype(np.float32))],
        'skip': ['backward']}),
    ('ClipByGlobalNorm', {
        'block': ClipByGlobalNorm(x=Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
                                  clip_norm=1.0, use_norm=None),
        'desc_inputs': [],
        'skip': ['backward']}),
    ('Embedding_1', {
        'block': Embedding(vocab_size=10, embedding_size=3),
        'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],