forked from mindspore-Ecosystem/mindspore
!1263 Modified clip_gradients to clip_grad
Merge pull request !1263 from Kang/master
This commit is contained in:
commit
89470cf29d
|
@ -30,7 +30,7 @@ from mindspore.train.parallel_utils import ParallelMode
|
||||||
from mindspore.communication.management import get_group_size
|
from mindspore.communication.management import get_group_size
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
|
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
|
||||||
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients
|
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad
|
||||||
from CRF import CRF
|
from CRF import CRF
|
||||||
|
|
||||||
GRADIENT_CLIP_TYPE = 1
|
GRADIENT_CLIP_TYPE = 1
|
||||||
|
@ -66,7 +66,6 @@ class BertFinetuneCell(nn.Cell):
|
||||||
degree = get_group_size()
|
degree = get_group_size()
|
||||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||||
self.clip_gradients = ClipGradients()
|
|
||||||
self.cast = P.Cast()
|
self.cast = P.Cast()
|
||||||
self.alloc_status = P.NPUAllocFloatStatus()
|
self.alloc_status = P.NPUAllocFloatStatus()
|
||||||
self.get_status = P.NPUGetFloatStatus()
|
self.get_status = P.NPUGetFloatStatus()
|
||||||
|
@ -110,7 +109,7 @@ class BertFinetuneCell(nn.Cell):
|
||||||
F.control_depend(loss, init)
|
F.control_depend(loss, init)
|
||||||
self.depend_parameter_use(clear_before_grad, scaling_sens)
|
self.depend_parameter_use(clear_before_grad, scaling_sens)
|
||||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||||
if self.reducer_flag:
|
if self.reducer_flag:
|
||||||
grads = self.grad_reducer(grads)
|
grads = self.grad_reducer(grads)
|
||||||
flag = self.get_status(init)
|
flag = self.get_status(init)
|
||||||
|
|
|
@ -32,44 +32,31 @@ from .bert_model import BertModel
|
||||||
GRADIENT_CLIP_TYPE = 1
|
GRADIENT_CLIP_TYPE = 1
|
||||||
GRADIENT_CLIP_VALUE = 1.0
|
GRADIENT_CLIP_VALUE = 1.0
|
||||||
|
|
||||||
|
_nn_clip_by_norm = nn.ClipByNorm()
|
||||||
|
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||||
|
@clip_grad.register("Number", "Number", "Tensor")
|
||||||
|
|
||||||
class ClipGradients(nn.Cell):
|
def _clip_grad(clip_type, clip_value, grad):
|
||||||
"""
|
"""
|
||||||
Clip gradients.
|
Clip gradients.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
grads (tuple[Tensor]): Gradients.
|
|
||||||
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||||
clip_value (float): Specifies how much to clip.
|
clip_value (float): Specifies how much to clip.
|
||||||
|
grad (tuple[Tensor]): Gradients.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
tuple[Tensor], clipped gradients.
|
tuple[Tensor], clipped gradients.
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
if clip_type != 0 and clip_type != 1:
|
||||||
super(ClipGradients, self).__init__()
|
return grad
|
||||||
self.clip_by_norm = nn.ClipByNorm()
|
dt = F.dtype(grad)
|
||||||
self.cast = P.Cast()
|
if clip_type == 0:
|
||||||
self.dtype = P.DType()
|
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||||
|
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||||
def construct(self,
|
else:
|
||||||
grads,
|
new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||||
clip_type,
|
return new_grad
|
||||||
clip_value):
|
|
||||||
if clip_type != 0 and clip_type != 1:
|
|
||||||
return grads
|
|
||||||
|
|
||||||
new_grads = ()
|
|
||||||
for grad in grads:
|
|
||||||
dt = self.dtype(grad)
|
|
||||||
if clip_type == 0:
|
|
||||||
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
|
|
||||||
self.cast(F.tuple_to_array((clip_value,)), dt))
|
|
||||||
else:
|
|
||||||
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
|
|
||||||
new_grads = new_grads + (t,)
|
|
||||||
|
|
||||||
return new_grads
|
|
||||||
|
|
||||||
|
|
||||||
class GetMaskedLMOutput(nn.Cell):
|
class GetMaskedLMOutput(nn.Cell):
|
||||||
"""
|
"""
|
||||||
|
@ -294,8 +281,8 @@ class BertTrainOneStepCell(nn.Cell):
|
||||||
degree = get_group_size()
|
degree = get_group_size()
|
||||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||||
|
|
||||||
self.clip_gradients = ClipGradients()
|
|
||||||
self.cast = P.Cast()
|
self.cast = P.Cast()
|
||||||
|
self.hyper_map = C.HyperMap()
|
||||||
|
|
||||||
def set_sens(self, value):
|
def set_sens(self, value):
|
||||||
self.sens = value
|
self.sens = value
|
||||||
|
@ -327,7 +314,7 @@ class BertTrainOneStepCell(nn.Cell):
|
||||||
masked_lm_weights,
|
masked_lm_weights,
|
||||||
self.cast(F.tuple_to_array((self.sens,)),
|
self.cast(F.tuple_to_array((self.sens,)),
|
||||||
mstype.float32))
|
mstype.float32))
|
||||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||||
if self.reducer_flag:
|
if self.reducer_flag:
|
||||||
# apply grad reducer on grads
|
# apply grad reducer on grads
|
||||||
grads = self.grad_reducer(grads)
|
grads = self.grad_reducer(grads)
|
||||||
|
@ -376,7 +363,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
||||||
degree = get_group_size()
|
degree = get_group_size()
|
||||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||||
self.clip_gradients = ClipGradients()
|
|
||||||
self.cast = P.Cast()
|
self.cast = P.Cast()
|
||||||
self.alloc_status = P.NPUAllocFloatStatus()
|
self.alloc_status = P.NPUAllocFloatStatus()
|
||||||
self.get_status = P.NPUGetFloatStatus()
|
self.get_status = P.NPUGetFloatStatus()
|
||||||
|
@ -427,7 +413,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
||||||
self.cast(scaling_sens,
|
self.cast(scaling_sens,
|
||||||
mstype.float32))
|
mstype.float32))
|
||||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||||
# apply grad reducer on grads
|
# apply grad reducer on grads
|
||||||
grads = self.grad_reducer(grads)
|
grads = self.grad_reducer(grads)
|
||||||
self.get_status(init)
|
self.get_status(init)
|
||||||
|
|
|
@ -19,11 +19,12 @@ import numpy as np
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
import mindspore.ops.composite as C
|
import mindspore.ops.composite as C
|
||||||
|
from mindspore.ops import functional as F
|
||||||
from mindspore.common.initializer import TruncatedNormal
|
from mindspore.common.initializer import TruncatedNormal
|
||||||
from mindspore.common.parameter import ParameterTuple
|
from mindspore.common.parameter import ParameterTuple
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.model_zoo.Bert_NEZHA import BertPretrainingLoss, GetNextSentenceOutput
|
from mindspore.model_zoo.Bert_NEZHA import BertPretrainingLoss, GetNextSentenceOutput
|
||||||
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients
|
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad
|
||||||
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig, \
|
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig, \
|
||||||
EmbeddingLookup, EmbeddingPostprocessor, BertOutput, RelaPosMatrixGenerator, \
|
EmbeddingLookup, EmbeddingPostprocessor, BertOutput, RelaPosMatrixGenerator, \
|
||||||
RelaPosEmbeddingsGenerator, SaturateCast, BertAttention, BertSelfAttention, \
|
RelaPosEmbeddingsGenerator, SaturateCast, BertAttention, BertSelfAttention, \
|
||||||
|
@ -80,12 +81,12 @@ class TrainStepWrapForAdam(nn.Cell):
|
||||||
self.network = network
|
self.network = network
|
||||||
self.weights = ParameterTuple(network.get_parameters())
|
self.weights = ParameterTuple(network.get_parameters())
|
||||||
self.optimizer = AdamWeightDecay(self.weights)
|
self.optimizer = AdamWeightDecay(self.weights)
|
||||||
self.clip_gradients = ClipGradients()
|
self.hyper_map = C.HyperMap()
|
||||||
|
|
||||||
def construct(self, x, sens):
|
def construct(self, x, sens):
|
||||||
weights = self.weights
|
weights = self.weights
|
||||||
grads = C.grad_by_list_with_sens(self.network, weights)(x, sens)
|
grads = C.grad_by_list_with_sens(self.network, weights)(x, sens)
|
||||||
grads = self.clip_gradients(grads, 1, 1.0)
|
grads = self.hyper_map(F.partial(clip_grad, 1, 1.0), grads)
|
||||||
return self.optimizer(grads)
|
return self.optimizer(grads)
|
||||||
|
|
||||||
|
|
||||||
|
@ -111,9 +112,10 @@ class TempC2Wrap(nn.Cell):
|
||||||
self.op = op
|
self.op = op
|
||||||
self.c1 = c1
|
self.c1 = c1
|
||||||
self.c2 = c2
|
self.c2 = c2
|
||||||
|
self.hyper_map = C.HyperMap()
|
||||||
|
|
||||||
def construct(self, x1):
|
def construct(self, x1):
|
||||||
x = self.op(x1, self.c1, self.c2)
|
x = self.hyper_map(F.partial(self.op, self.c1, self.c2), x1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -405,7 +407,7 @@ test_case_cell_ops = [
|
||||||
'desc_inputs': [[1, 64]],
|
'desc_inputs': [[1, 64]],
|
||||||
'skip': ['backward']}),
|
'skip': ['backward']}),
|
||||||
('ClipGradients', {
|
('ClipGradients', {
|
||||||
'block': TempC2Wrap(ClipGradients(), 1, 1.0),
|
'block': TempC2Wrap(clip_grad, 1, 1.0),
|
||||||
'desc_inputs': [tuple(convert(shp) for shp in [[1], [1], [1]])],
|
'desc_inputs': [tuple(convert(shp) for shp in [[1], [1], [1]])],
|
||||||
'skip': ['backward', 'exec']}),
|
'skip': ['backward', 'exec']}),
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in New Issue