diff --git a/example/bert_clue/utils.py b/example/bert_clue/utils.py index 6bf0c87a67d..1d05b957404 100644 --- a/example/bert_clue/utils.py +++ b/example/bert_clue/utils.py @@ -30,7 +30,7 @@ from mindspore.train.parallel_utils import ParallelMode from mindspore.communication.management import get_group_size from mindspore import context from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel -from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients +from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad from CRF import CRF GRADIENT_CLIP_TYPE = 1 @@ -66,7 +66,6 @@ class BertFinetuneCell(nn.Cell): degree = get_group_size() self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) - self.clip_gradients = ClipGradients() self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() @@ -110,7 +109,7 @@ class BertFinetuneCell(nn.Cell): F.control_depend(loss, init) self.depend_parameter_use(clear_before_grad, scaling_sens) grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) - grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if self.reducer_flag: grads = self.grad_reducer(grads) flag = self.get_status(init) diff --git a/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py b/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py index c324f10f6b5..a5507fa782b 100644 --- a/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py +++ b/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py @@ -32,44 +32,31 @@ from .bert_model import BertModel GRADIENT_CLIP_TYPE = 1 GRADIENT_CLIP_VALUE = 1.0 +_nn_clip_by_norm = nn.ClipByNorm() +clip_grad = C.MultitypeFuncGraph("clip_grad") +@clip_grad.register("Number", "Number", "Tensor") -class ClipGradients(nn.Cell): +def _clip_grad(clip_type, clip_value, grad): """ Clip gradients. Inputs: - grads (tuple[Tensor]): Gradients. clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. clip_value (float): Specifies how much to clip. + grad (tuple[Tensor]): Gradients. Outputs: tuple[Tensor], clipped gradients. """ - def __init__(self): - super(ClipGradients, self).__init__() - self.clip_by_norm = nn.ClipByNorm() - self.cast = P.Cast() - self.dtype = P.DType() - - def construct(self, - grads, - clip_type, - clip_value): - if clip_type != 0 and clip_type != 1: - return grads - - new_grads = () - for grad in grads: - dt = self.dtype(grad) - if clip_type == 0: - t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt), - self.cast(F.tuple_to_array((clip_value,)), dt)) - else: - t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt)) - new_grads = new_grads + (t,) - - return new_grads - + if clip_type != 0 and clip_type != 1: + return grad + dt = F.dtype(grad) + if clip_type == 0: + new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) + else: + new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) + return new_grad class GetMaskedLMOutput(nn.Cell): """ @@ -294,8 +281,8 @@ class BertTrainOneStepCell(nn.Cell): degree = get_group_size() self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.clip_gradients = ClipGradients() self.cast = P.Cast() + self.hyper_map = C.HyperMap() def set_sens(self, value): self.sens = value @@ -327,7 +314,7 @@ class BertTrainOneStepCell(nn.Cell): masked_lm_weights, self.cast(F.tuple_to_array((self.sens,)), mstype.float32)) - grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if self.reducer_flag: # apply grad reducer on grads grads = self.grad_reducer(grads) @@ -376,7 +363,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): degree = get_group_size() self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) - self.clip_gradients = ClipGradients() self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() @@ -427,7 +413,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): self.cast(scaling_sens, mstype.float32)) grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) - grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) # apply grad reducer on grads grads = self.grad_reducer(grads) self.get_status(init) diff --git a/tests/ut/python/model/test_bert_cell.py b/tests/ut/python/model/test_bert_cell.py index 3de6073787d..817a75b6b49 100644 --- a/tests/ut/python/model/test_bert_cell.py +++ b/tests/ut/python/model/test_bert_cell.py @@ -19,11 +19,12 @@ import numpy as np import mindspore.common.dtype as mstype import mindspore.nn as nn import mindspore.ops.composite as C +from mindspore.ops import functional as F from mindspore.common.initializer import TruncatedNormal from mindspore.common.parameter import ParameterTuple from mindspore.common.tensor import Tensor from mindspore.model_zoo.Bert_NEZHA import BertPretrainingLoss, GetNextSentenceOutput -from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients +from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig, \ EmbeddingLookup, EmbeddingPostprocessor, BertOutput, RelaPosMatrixGenerator, \ RelaPosEmbeddingsGenerator, SaturateCast, BertAttention, BertSelfAttention, \ @@ -80,12 +81,12 @@ class TrainStepWrapForAdam(nn.Cell): self.network = network self.weights = ParameterTuple(network.get_parameters()) self.optimizer = AdamWeightDecay(self.weights) - self.clip_gradients = ClipGradients() + self.hyper_map = C.HyperMap() def construct(self, x, sens): weights = self.weights grads = C.grad_by_list_with_sens(self.network, weights)(x, sens) - grads = self.clip_gradients(grads, 1, 1.0) + grads = self.hyper_map(F.partial(clip_grad, 1, 1.0), grads) return self.optimizer(grads) @@ -111,9 +112,10 @@ class TempC2Wrap(nn.Cell): self.op = op self.c1 = c1 self.c2 = c2 + self.hyper_map = C.HyperMap() def construct(self, x1): - x = self.op(x1, self.c1, self.c2) + x = self.hyper_map(F.partial(self.op, self.c1, self.c2), x1) return x @@ -405,7 +407,7 @@ test_case_cell_ops = [ 'desc_inputs': [[1, 64]], 'skip': ['backward']}), ('ClipGradients', { - 'block': TempC2Wrap(ClipGradients(), 1, 1.0), + 'block': TempC2Wrap(clip_grad, 1, 1.0), 'desc_inputs': [tuple(convert(shp) for shp in [[1], [1], [1]])], 'skip': ['backward', 'exec']}), ]