forked from mindspore-Ecosystem/mindspore
!1697 optimize clip_norm
Merge pull request !1697 from chenhaozhe/bert-optimization
This commit is contained in:
commit
e79df7c169
|
@ -588,7 +588,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|||
graph->set_output_null(is_trace_back);
|
||||
AddParameterToGraphInputs(func_graph->parameters(), graph.get());
|
||||
MS_EXCEPTION_IF_NULL(context_);
|
||||
FuncGraphManagerPtr manager = context_->manager();
|
||||
FuncGraphManagerPtr manager = MakeManager({graph});
|
||||
if (manager) {
|
||||
manager->AddFuncGraph(graph);
|
||||
graph->set_manager(manager);
|
||||
|
|
|
@ -22,6 +22,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.functional import identity
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore._extends import cell_attr_register
|
||||
from mindspore.common.api import ms_function
|
||||
|
@ -236,6 +237,13 @@ class Dense(Cell):
|
|||
return str_info
|
||||
|
||||
|
||||
@constexpr
|
||||
def _is_equal_one(x):
|
||||
if x is None:
|
||||
return False
|
||||
return bool(x.asnumpy().mean() == 1.0)
|
||||
|
||||
|
||||
class ClipByNorm(Cell):
|
||||
r"""
|
||||
Clips tensor values to a maximum :math:`L_2`-norm.
|
||||
|
@ -290,7 +298,10 @@ class ClipByNorm(Cell):
|
|||
l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
|
||||
l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)
|
||||
|
||||
intermediate = x * clip_norm
|
||||
if _is_equal_one(clip_norm):
|
||||
intermediate = x
|
||||
else:
|
||||
intermediate = x * clip_norm
|
||||
max_norm = self.max_op(l2norm, clip_norm)
|
||||
values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
|
||||
values_clip = self.reshape(values_clip, self.shape(x))
|
||||
|
|
|
@ -32,7 +32,6 @@ from .bert_model import BertModel
|
|||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
||||
_nn_clip_by_norm = nn.ClipByNorm()
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
|
||||
|
@ -57,7 +56,7 @@ def _clip_grad(clip_type, clip_value, grad):
|
|||
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
|
|||
bert_net_cfg = BertConfig(
|
||||
batch_size=32,
|
||||
seq_length=128,
|
||||
vocab_size=21136,
|
||||
vocab_size=21128,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
|
@ -77,7 +77,7 @@ if cfg.bert_network == 'nezha':
|
|||
bert_net_cfg = BertConfig(
|
||||
batch_size=32,
|
||||
seq_length=128,
|
||||
vocab_size=21136,
|
||||
vocab_size=21128,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
|
@ -98,7 +98,7 @@ if cfg.bert_network == 'large':
|
|||
bert_net_cfg = BertConfig(
|
||||
batch_size=16,
|
||||
seq_length=512,
|
||||
vocab_size=30528,
|
||||
vocab_size=30522,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
|
|
|
@ -26,3 +26,19 @@ def test_clip_by_norm():
|
|||
x = Tensor(np.array([[-2, 0, 0], [0, 3, 4]]).astype(np.float32))
|
||||
clip_norm = Tensor(np.array([1]).astype(np.float32))
|
||||
clip_by_norm(x, clip_norm)
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
def test_clip_by_norm_const():
|
||||
class Network(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
self.norm_value = Tensor(np.array([1]).astype(np.float32))
|
||||
self.clip = nn.ClipByNorm()
|
||||
|
||||
def construct(self, x):
|
||||
return self.clip(x, self.norm_value)
|
||||
|
||||
net = Network()
|
||||
x = Tensor(np.array([[-2, 0, 0], [0, 3, 4]]).astype(np.float32))
|
||||
output = net(x)
|
||||
|
|
Loading…
Reference in New Issue