!1697 optimize clip_norm

Merge pull request !1697 from chenhaozhe/bert-optimization
This commit is contained in:
mindspore-ci-bot 2020-06-02 11:28:53 +08:00 committed by Gitee
commit e79df7c169
5 changed files with 33 additions and 7 deletions

View File

@ -588,7 +588,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
graph->set_output_null(is_trace_back); graph->set_output_null(is_trace_back);
AddParameterToGraphInputs(func_graph->parameters(), graph.get()); AddParameterToGraphInputs(func_graph->parameters(), graph.get());
MS_EXCEPTION_IF_NULL(context_); MS_EXCEPTION_IF_NULL(context_);
FuncGraphManagerPtr manager = context_->manager(); FuncGraphManagerPtr manager = MakeManager({graph});
if (manager) { if (manager) {
manager->AddFuncGraph(graph); manager->AddFuncGraph(graph);
graph->set_manager(manager); graph->set_manager(manager);

View File

@ -22,6 +22,7 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops.functional import identity from mindspore.ops.functional import identity
from mindspore.ops.operations import _inner_ops as inner from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.primitive import constexpr
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
@ -236,6 +237,13 @@ class Dense(Cell):
return str_info return str_info
@constexpr
def _is_equal_one(x):
if x is None:
return False
return bool(x.asnumpy().mean() == 1.0)
class ClipByNorm(Cell): class ClipByNorm(Cell):
r""" r"""
Clips tensor values to a maximum :math:`L_2`-norm. Clips tensor values to a maximum :math:`L_2`-norm.
@ -290,7 +298,10 @@ class ClipByNorm(Cell):
l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum))) l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum) l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)
intermediate = x * clip_norm if _is_equal_one(clip_norm):
intermediate = x
else:
intermediate = x * clip_norm
max_norm = self.max_op(l2norm, clip_norm) max_norm = self.max_op(l2norm, clip_norm)
values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1) values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
values_clip = self.reshape(values_clip, self.shape(x)) values_clip = self.reshape(values_clip, self.shape(x))

View File

@ -32,7 +32,6 @@ from .bert_model import BertModel
GRADIENT_CLIP_TYPE = 1 GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0 GRADIENT_CLIP_VALUE = 1.0
_nn_clip_by_norm = nn.ClipByNorm()
clip_grad = C.MultitypeFuncGraph("clip_grad") clip_grad = C.MultitypeFuncGraph("clip_grad")
@ -57,7 +56,7 @@ def _clip_grad(clip_type, clip_value, grad):
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt)) F.cast(F.tuple_to_array((clip_value,)), dt))
else: else:
new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad return new_grad

View File

@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
bert_net_cfg = BertConfig( bert_net_cfg = BertConfig(
batch_size=32, batch_size=32,
seq_length=128, seq_length=128,
vocab_size=21136, vocab_size=21128,
hidden_size=768, hidden_size=768,
num_hidden_layers=12, num_hidden_layers=12,
num_attention_heads=12, num_attention_heads=12,
@ -77,7 +77,7 @@ if cfg.bert_network == 'nezha':
bert_net_cfg = BertConfig( bert_net_cfg = BertConfig(
batch_size=32, batch_size=32,
seq_length=128, seq_length=128,
vocab_size=21136, vocab_size=21128,
hidden_size=1024, hidden_size=1024,
num_hidden_layers=24, num_hidden_layers=24,
num_attention_heads=16, num_attention_heads=16,
@ -98,7 +98,7 @@ if cfg.bert_network == 'large':
bert_net_cfg = BertConfig( bert_net_cfg = BertConfig(
batch_size=16, batch_size=16,
seq_length=512, seq_length=512,
vocab_size=30528, vocab_size=30522,
hidden_size=1024, hidden_size=1024,
num_hidden_layers=24, num_hidden_layers=24,
num_attention_heads=16, num_attention_heads=16,

View File

@ -26,3 +26,19 @@ def test_clip_by_norm():
x = Tensor(np.array([[-2, 0, 0], [0, 3, 4]]).astype(np.float32)) x = Tensor(np.array([[-2, 0, 0], [0, 3, 4]]).astype(np.float32))
clip_norm = Tensor(np.array([1]).astype(np.float32)) clip_norm = Tensor(np.array([1]).astype(np.float32))
clip_by_norm(x, clip_norm) clip_by_norm(x, clip_norm)
@non_graph_engine
def test_clip_by_norm_const():
class Network(nn.Cell):
def __init__(self):
super(Network, self).__init__()
self.norm_value = Tensor(np.array([1]).astype(np.float32))
self.clip = nn.ClipByNorm()
def construct(self, x):
return self.clip(x, self.norm_value)
net = Network()
x = Tensor(np.array([[-2, 0, 0], [0, 3, 4]]).astype(np.float32))
output = net(x)