forked from mindspore-Ecosystem/mindspore
Modified clip_gradients to clip_grad
This commit is contained in:
parent
2a1aad0f55
commit
bfff7c0a2f
|
@ -30,7 +30,7 @@ from mindspore.train.parallel_utils import ParallelMode
|
|||
from mindspore.communication.management import get_group_size
|
||||
from mindspore import context
|
||||
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
|
||||
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients
|
||||
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad
|
||||
from CRF import CRF
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
|
@ -66,7 +66,6 @@ class BertFinetuneCell(nn.Cell):
|
|||
degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.clip_gradients = ClipGradients()
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
|
@ -110,7 +109,7 @@ class BertFinetuneCell(nn.Cell):
|
|||
F.control_depend(loss, init)
|
||||
self.depend_parameter_use(clear_before_grad, scaling_sens)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
flag = self.get_status(init)
|
||||
|
|
|
@ -32,44 +32,31 @@ from .bert_model import BertModel
|
|||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
||||
_nn_clip_by_norm = nn.ClipByNorm()
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
@clip_grad.register("Number", "Number", "Tensor")
|
||||
|
||||
class ClipGradients(nn.Cell):
|
||||
def _clip_grad(clip_type, clip_value, grad):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Inputs:
|
||||
grads (tuple[Tensor]): Gradients.
|
||||
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||
clip_value (float): Specifies how much to clip.
|
||||
grad (tuple[Tensor]): Gradients.
|
||||
|
||||
Outputs:
|
||||
tuple[Tensor], clipped gradients.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ClipGradients, self).__init__()
|
||||
self.clip_by_norm = nn.ClipByNorm()
|
||||
self.cast = P.Cast()
|
||||
self.dtype = P.DType()
|
||||
|
||||
def construct(self,
|
||||
grads,
|
||||
clip_type,
|
||||
clip_value):
|
||||
if clip_type != 0 and clip_type != 1:
|
||||
return grads
|
||||
|
||||
new_grads = ()
|
||||
for grad in grads:
|
||||
dt = self.dtype(grad)
|
||||
if clip_type == 0:
|
||||
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
new_grads = new_grads + (t,)
|
||||
|
||||
return new_grads
|
||||
|
||||
if clip_type != 0 and clip_type != 1:
|
||||
return grad
|
||||
dt = F.dtype(grad)
|
||||
if clip_type == 0:
|
||||
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
|
||||
class GetMaskedLMOutput(nn.Cell):
|
||||
"""
|
||||
|
@ -294,8 +281,8 @@ class BertTrainOneStepCell(nn.Cell):
|
|||
degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
|
||||
self.clip_gradients = ClipGradients()
|
||||
self.cast = P.Cast()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def set_sens(self, value):
|
||||
self.sens = value
|
||||
|
@ -327,7 +314,7 @@ class BertTrainOneStepCell(nn.Cell):
|
|||
masked_lm_weights,
|
||||
self.cast(F.tuple_to_array((self.sens,)),
|
||||
mstype.float32))
|
||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
if self.reducer_flag:
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
|
@ -376,7 +363,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.clip_gradients = ClipGradients()
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
|
@ -427,7 +413,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
self.get_status(init)
|
||||
|
|
|
@ -19,11 +19,12 @@ import numpy as np
|
|||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.model_zoo.Bert_NEZHA import BertPretrainingLoss, GetNextSentenceOutput
|
||||
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients
|
||||
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad
|
||||
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig, \
|
||||
EmbeddingLookup, EmbeddingPostprocessor, BertOutput, RelaPosMatrixGenerator, \
|
||||
RelaPosEmbeddingsGenerator, SaturateCast, BertAttention, BertSelfAttention, \
|
||||
|
@ -80,12 +81,12 @@ class TrainStepWrapForAdam(nn.Cell):
|
|||
self.network = network
|
||||
self.weights = ParameterTuple(network.get_parameters())
|
||||
self.optimizer = AdamWeightDecay(self.weights)
|
||||
self.clip_gradients = ClipGradients()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self, x, sens):
|
||||
weights = self.weights
|
||||
grads = C.grad_by_list_with_sens(self.network, weights)(x, sens)
|
||||
grads = self.clip_gradients(grads, 1, 1.0)
|
||||
grads = self.hyper_map(F.partial(clip_grad, 1, 1.0), grads)
|
||||
return self.optimizer(grads)
|
||||
|
||||
|
||||
|
@ -111,9 +112,10 @@ class TempC2Wrap(nn.Cell):
|
|||
self.op = op
|
||||
self.c1 = c1
|
||||
self.c2 = c2
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self, x1):
|
||||
x = self.op(x1, self.c1, self.c2)
|
||||
x = self.hyper_map(F.partial(self.op, self.c1, self.c2), x1)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -405,7 +407,7 @@ test_case_cell_ops = [
|
|||
'desc_inputs': [[1, 64]],
|
||||
'skip': ['backward']}),
|
||||
('ClipGradients', {
|
||||
'block': TempC2Wrap(ClipGradients(), 1, 1.0),
|
||||
'block': TempC2Wrap(clip_grad, 1, 1.0),
|
||||
'desc_inputs': [tuple(convert(shp) for shp in [[1], [1], [1]])],
|
||||
'skip': ['backward', 'exec']}),
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue