From 435fc12e28798bed4618e12ff121e1812bdced30 Mon Sep 17 00:00:00 2001
From: chenhaozhe
Date: Fri, 29 May 2020 20:06:16 +0800
Subject: [PATCH] optimize clip_norm

---
 mindspore/ccsrc/session/session_basic.cc    |  2 +-
 mindspore/nn/layer/basic.py                 | 13 ++++++++++++-
 model_zoo/bert/src/bert_for_pre_training.py |  3 +--
 model_zoo/bert/src/config.py                |  6 +++---
 tests/ut/python/nn/test_clip_by_norm.py     | 16 ++++++++++++++++
 5 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc
index b1bfefcac1a..f0f4fe419e1 100644
--- a/mindspore/ccsrc/session/session_basic.cc
+++ b/mindspore/ccsrc/session/session_basic.cc
@@ -588,7 +588,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
   graph->set_output_null(is_trace_back);
   AddParameterToGraphInputs(func_graph->parameters(), graph.get());
   MS_EXCEPTION_IF_NULL(context_);
-  FuncGraphManagerPtr manager = context_->manager();
+  FuncGraphManagerPtr manager = MakeManager({graph});
   if (manager) {
     manager->AddFuncGraph(graph);
     graph->set_manager(manager);
diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index 8f4e468e0b6..115f72b319e 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -22,6 +22,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops.functional import identity
 from mindspore.ops.operations import _inner_ops as inner
+from mindspore.ops.primitive import constexpr
 from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore.common.api import ms_function
@@ -236,6 +237,13 @@ class Dense(Cell):
         return str_info
 
 
+@constexpr
+def _is_equal_one(x):
+    if x is None:
+        return False
+    return bool(x.asnumpy().mean() == 1.0)
+
+
 class ClipByNorm(Cell):
     r"""
     Clips tensor values to a maximum :math:`L_2`-norm.
@@ -290,7 +298,10 @@ class ClipByNorm(Cell):
         l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
         l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)
 
-        intermediate = x * clip_norm
+        if _is_equal_one(clip_norm):
+            intermediate = x
+        else:
+            intermediate = x * clip_norm
         max_norm = self.max_op(l2norm, clip_norm)
         values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
         values_clip = self.reshape(values_clip, self.shape(x))
diff --git a/model_zoo/bert/src/bert_for_pre_training.py b/model_zoo/bert/src/bert_for_pre_training.py
index 600512b4a77..976f1a3c43f 100644
--- a/model_zoo/bert/src/bert_for_pre_training.py
+++ b/model_zoo/bert/src/bert_for_pre_training.py
@@ -32,7 +32,6 @@ from .bert_model import BertModel
 
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
-_nn_clip_by_norm = nn.ClipByNorm()
 
 clip_grad = C.MultitypeFuncGraph("clip_grad")
 
@@ -57,7 +56,7 @@ def _clip_grad(clip_type, clip_value, grad):
         new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                    F.cast(F.tuple_to_array((clip_value,)), dt))
     else:
-        new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
+        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
     return new_grad
 
 
diff --git a/model_zoo/bert/src/config.py b/model_zoo/bert/src/config.py
index d1062b78eec..812f0c2f180 100644
--- a/model_zoo/bert/src/config.py
+++ b/model_zoo/bert/src/config.py
@@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
@@ -77,7 +77,7 @@ if cfg.bert_network == 'nezha':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
@@ -98,7 +98,7 @@ if cfg.bert_network == 'large':
     bert_net_cfg = BertConfig(
         batch_size=16,
         seq_length=512,
-        vocab_size=30528,
+        vocab_size=30522,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
diff --git a/tests/ut/python/nn/test_clip_by_norm.py b/tests/ut/python/nn/test_clip_by_norm.py
index ff7d1281081..813fada90a5 100644
--- a/tests/ut/python/nn/test_clip_by_norm.py
+++ b/tests/ut/python/nn/test_clip_by_norm.py
@@ -26,3 +26,19 @@ def test_clip_by_norm():
     x = Tensor(np.array([[-2, 0, 0], [0, 3, 4]]).astype(np.float32))
     clip_norm = Tensor(np.array([1]).astype(np.float32))
     clip_by_norm(x, clip_norm)
+
+
+@non_graph_engine
+def test_clip_by_norm_const():
+    class Network(nn.Cell):
+        def __init__(self):
+            super(Network, self).__init__()
+            self.norm_value = Tensor(np.array([1]).astype(np.float32))
+            self.clip = nn.ClipByNorm()
+
+        def construct(self, x):
+            return self.clip(x, self.norm_value)
+
+    net = Network()
+    x = Tensor(np.array([[-2, 0, 0], [0, 3, 4]]).astype(np.float32))
+    output = net(x)
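
Note: below is a minimal usage sketch (not part of the patch) of the path this change optimizes, mirroring the new unit test. Because @constexpr functions run during graph compilation, _is_equal_one only sees clip_norm when it is a compile-time constant, e.g. a Tensor held as a Cell attribute; a clip_norm passed in as a construct() argument arrives as None, so _is_equal_one returns False and the general x * clip_norm path is kept. The class name ConstClipNet is illustrative, not from the patch.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

class ConstClipNet(nn.Cell):
    """Holds clip_norm as a constant so @constexpr can see its value."""
    def __init__(self):
        super(ConstClipNet, self).__init__()
        # A constant 1.0 lets _is_equal_one fold to True at compile time,
        # so the x * clip_norm multiply is elided (intermediate = x).
        self.norm_value = Tensor(np.array([1.0]).astype(np.float32))
        self.clip = nn.ClipByNorm()

    def construct(self, x):
        return self.clip(x, self.norm_value)

net = ConstClipNet()
x = Tensor(np.array([[-2, 0, 0], [0, 3, 4]]).astype(np.float32))
output = net(x)  # x rescaled so its overall L2-norm is at most 1.0

The numeric result is unchanged by the fast path: when clip_norm == 1.0, multiplying by clip_norm is an identity, so skipping it only removes a redundant Mul from the compiled graph.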