!1263 Modified clip_gradients to clip_grad

Merge pull request !1263 from Kang/master
mindspore-ci-bot 2020-05-20 14:15:05 +08:00 committed by Gitee
commit 89470cf29d
3 changed files with 26 additions and 39 deletions
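The change drops the ClipGradients cell, which looped over the whole gradient tuple inside construct, and replaces it with a per-tensor clip_grad function registered on a C.MultitypeFuncGraph and applied through C.HyperMap. Call sites migrate as in the minimal sketch below; the sketch is assembled from the diff for illustration only, and the ClipGradsSketch cell and its grads argument are placeholders, not code from this commit.

import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore.ops import functional as F
# clip_grad is the MultitypeFuncGraph this commit adds to bert_for_pre_training
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad

GRADIENT_CLIP_TYPE = 1    # 0 clips by value, 1 clips by norm
GRADIENT_CLIP_VALUE = 1.0

class ClipGradsSketch(nn.Cell):
    """Placeholder cell illustrating the call-site migration."""
    def __init__(self):
        super(ClipGradsSketch, self).__init__()
        # Before: self.clip_gradients = ClipGradients()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        # Before: grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
        # After: HyperMap applies the registered clip_grad to every Tensor in the tuple.
        return self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)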

View File

@@ -30,7 +30,7 @@ from mindspore.train.parallel_utils import ParallelMode
 from mindspore.communication.management import get_group_size
 from mindspore import context
 from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
-from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients
+from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad
 from CRF import CRF
 GRADIENT_CLIP_TYPE = 1
@@ -66,7 +66,6 @@ class BertFinetuneCell(nn.Cell):
             degree = get_group_size()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
-        self.clip_gradients = ClipGradients()
         self.cast = P.Cast()
         self.alloc_status = P.NPUAllocFloatStatus()
         self.get_status = P.NPUGetFloatStatus()
@@ -110,7 +109,7 @@ class BertFinetuneCell(nn.Cell):
         F.control_depend(loss, init)
         self.depend_parameter_use(clear_before_grad, scaling_sens)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
-        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
             grads = self.grad_reducer(grads)
         flag = self.get_status(init)

View File

@@ -32,44 +32,31 @@ from .bert_model import BertModel
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
 
-class ClipGradients(nn.Cell):
+_nn_clip_by_norm = nn.ClipByNorm()
+clip_grad = C.MultitypeFuncGraph("clip_grad")
+
+@clip_grad.register("Number", "Number", "Tensor")
+def _clip_grad(clip_type, clip_value, grad):
     """
     Clip gradients.
 
     Inputs:
-        grads (tuple[Tensor]): Gradients.
         clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
         clip_value (float): Specifies how much to clip.
+        grad (tuple[Tensor]): Gradients.
 
     Outputs:
         tuple[Tensor], clipped gradients.
     """
-    def __init__(self):
-        super(ClipGradients, self).__init__()
-        self.clip_by_norm = nn.ClipByNorm()
-        self.cast = P.Cast()
-        self.dtype = P.DType()
-
-    def construct(self,
-                  grads,
-                  clip_type,
-                  clip_value):
-        if clip_type != 0 and clip_type != 1:
-            return grads
-        new_grads = ()
-        for grad in grads:
-            dt = self.dtype(grad)
-            if clip_type == 0:
-                t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
-                                    self.cast(F.tuple_to_array((clip_value,)), dt))
-            else:
-                t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
-            new_grads = new_grads + (t,)
-        return new_grads
+    if clip_type != 0 and clip_type != 1:
+        return grad
+    dt = F.dtype(grad)
+    if clip_type == 0:
+        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
+                                   F.cast(F.tuple_to_array((clip_value,)), dt))
+    else:
+        new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
+    return new_grad
 
 
 class GetMaskedLMOutput(nn.Cell):
     """
@@ -294,8 +281,8 @@ class BertTrainOneStepCell(nn.Cell):
             degree = get_group_size()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-        self.clip_gradients = ClipGradients()
         self.cast = P.Cast()
+        self.hyper_map = C.HyperMap()
 
     def set_sens(self, value):
         self.sens = value
@@ -327,7 +314,7 @@
                                                 masked_lm_weights,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
-        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
@@ -376,7 +363,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
             degree = get_group_size()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
-        self.clip_gradients = ClipGradients()
         self.cast = P.Cast()
         self.alloc_status = P.NPUAllocFloatStatus()
         self.get_status = P.NPUGetFloatStatus()
@@ -427,7 +413,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
-        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         self.get_status(init)

View File

@@ -19,11 +19,12 @@ import numpy as np
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 import mindspore.ops.composite as C
+from mindspore.ops import functional as F
 from mindspore.common.initializer import TruncatedNormal
 from mindspore.common.parameter import ParameterTuple
 from mindspore.common.tensor import Tensor
 from mindspore.model_zoo.Bert_NEZHA import BertPretrainingLoss, GetNextSentenceOutput
-from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients
+from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad
 from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig, \
     EmbeddingLookup, EmbeddingPostprocessor, BertOutput, RelaPosMatrixGenerator, \
     RelaPosEmbeddingsGenerator, SaturateCast, BertAttention, BertSelfAttention, \
@@ -80,12 +81,12 @@ class TrainStepWrapForAdam(nn.Cell):
         self.network = network
         self.weights = ParameterTuple(network.get_parameters())
         self.optimizer = AdamWeightDecay(self.weights)
-        self.clip_gradients = ClipGradients()
+        self.hyper_map = C.HyperMap()
 
     def construct(self, x, sens):
         weights = self.weights
         grads = C.grad_by_list_with_sens(self.network, weights)(x, sens)
-        grads = self.clip_gradients(grads, 1, 1.0)
+        grads = self.hyper_map(F.partial(clip_grad, 1, 1.0), grads)
         return self.optimizer(grads)
@@ -111,9 +112,10 @@ class TempC2Wrap(nn.Cell):
         self.op = op
         self.c1 = c1
         self.c2 = c2
+        self.hyper_map = C.HyperMap()
 
     def construct(self, x1):
-        x = self.op(x1, self.c1, self.c2)
+        x = self.hyper_map(F.partial(self.op, self.c1, self.c2), x1)
         return x
@@ -405,7 +407,7 @@ test_case_cell_ops = [
         'desc_inputs': [[1, 64]],
         'skip': ['backward']}),
     ('ClipGradients', {
-        'block': TempC2Wrap(ClipGradients(), 1, 1.0),
+        'block': TempC2Wrap(clip_grad, 1, 1.0),
         'desc_inputs': [tuple(convert(shp) for shp in [[1], [1], [1]])],
         'skip': ['backward', 'exec']}),
 ]
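
Note on the new structure: clip_grad is registered for the type signature ("Number", "Number", "Tensor"), so the registered _clip_grad body handles one gradient tensor at a time; HyperMap, with F.partial binding the clip type and clip value, then maps it over the whole gradient tuple. This replaces the Python for-loop that ClipGradients ran inside construct with a graph-level element-wise map. Below is a standalone sketch of the same pattern, using only APIs that appear in this diff; the names scale_each and ScaleTuple are illustrative and not part of the repository.

import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore.ops import functional as F

# Illustrative MultitypeFuncGraph: scale every Tensor in a tuple by a scalar factor.
scale_each = C.MultitypeFuncGraph("scale_each")

@scale_each.register("Number", "Tensor")
def _scale_each(factor, t):
    # Invoked once per tuple element; the 1-element tensor broadcasts against t.
    return t * F.cast(F.tuple_to_array((factor,)), F.dtype(t))

class ScaleTuple(nn.Cell):
    """Maps the registered function over a tuple of tensors, as the BERT cells now do with clip_grad."""
    def __init__(self, factor):
        super(ScaleTuple, self).__init__()
        self.factor = factor
        self.hyper_map = C.HyperMap()

    def construct(self, tensors):
        return self.hyper_map(F.partial(scale_each, self.factor), tensors)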