forked from mindspore-Ecosystem/mindspore

!1697 optimize clip_norm
Merge pull request !1697 from chenhaozhe/bert-optimization
commit e79df7c169
@@ -588,7 +588,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
   graph->set_output_null(is_trace_back);
   AddParameterToGraphInputs(func_graph->parameters(), graph.get());
   MS_EXCEPTION_IF_NULL(context_);
-  FuncGraphManagerPtr manager = context_->manager();
+  FuncGraphManagerPtr manager = MakeManager({graph});
   if (manager) {
     manager->AddFuncGraph(graph);
     graph->set_manager(manager);
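Note on the hunk above: constructing the manager with MakeManager({graph}) appears to give each backend KernelGraph its own FuncGraphManager instead of borrowing the shared front-end context's manager; the subsequent null check and the AddFuncGraph/set_manager calls are unchanged.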
@@ -22,6 +22,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops.functional import identity
 from mindspore.ops.operations import _inner_ops as inner
+from mindspore.ops.primitive import constexpr
 from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore.common.api import ms_function
@@ -236,6 +237,13 @@ class Dense(Cell):
         return str_info


+@constexpr
+def _is_equal_one(x):
+    if x is None:
+        return False
+    return bool(x.asnumpy().mean() == 1.0)
+
+
 class ClipByNorm(Cell):
     r"""
     Clips tensor values to a maximum :math:`L_2`-norm.
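Background for the new helper: a function decorated with @constexpr is evaluated in Python while the graph is compiled, and for any argument that is not a compile-time constant it receives None instead of a value. That is why _is_equal_one returns False on None: a runtime-valued clip_norm cannot be proven equal to 1, so the general path is kept. The same conservative pattern in a hypothetical helper (illustrative name, not from the diff):

    from mindspore.ops.primitive import constexpr

    @constexpr
    def _known_positive(v):
        # At graph-compile time, v is the constant tensor's value, or None
        # when the input is only known at run time; default to the safe answer.
        if v is None:
            return False
        return bool(v.asnumpy().min() > 0)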
@@ -290,7 +298,10 @@ class ClipByNorm(Cell):
         l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
         l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)

-        intermediate = x * clip_norm
+        if _is_equal_one(clip_norm):
+            intermediate = x
+        else:
+            intermediate = x * clip_norm
         max_norm = self.max_op(l2norm, clip_norm)
         values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
         values_clip = self.reshape(values_clip, self.shape(x))
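A quick numpy check (not part of the change) of why the fast path is safe: with the default reduction over all axes, ClipByNorm computes x * clip_norm / max(||x||_2, clip_norm), and multiplying by a clip norm of exactly 1.0 is the identity, so eliding it saves an elementwise multiply per call without changing the result.

    import numpy as np

    x = np.array([[-2, 0, 0], [0, 3, 4]], dtype=np.float32)
    clip_norm = 1.0
    l2norm = np.sqrt((x ** 2).sum())            # sqrt(4 + 9 + 16) = sqrt(29)

    with_mul = x * clip_norm / max(l2norm, clip_norm)
    without_mul = x / max(l2norm, clip_norm)    # the elided multiply
    assert np.allclose(with_mul, without_mul)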
@@ -32,7 +32,6 @@ from .bert_model import BertModel
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0

-_nn_clip_by_norm = nn.ClipByNorm()
 clip_grad = C.MultitypeFuncGraph("clip_grad")

@@ -57,7 +56,7 @@ def _clip_grad(clip_type, clip_value, grad):
         new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                    F.cast(F.tuple_to_array((clip_value,)), dt))
     else:
-        new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
+        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
     return new_grad

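The module-level _nn_clip_by_norm singleton is removed; _clip_grad now constructs nn.ClipByNorm() inline, presumably so each specialization of the multitype func graph owns its cell rather than sharing one instance across compiled graphs. For context, the BERT scripts apply clip_grad over the whole gradient tuple with a HyperMap, roughly like this (a sketch reusing this file's clip_grad and constants):

    import mindspore.nn as nn
    from mindspore.ops import composite as C
    from mindspore.ops import functional as F

    class ClipGradients(nn.Cell):
        """Apply clip_grad to every gradient in a tuple (sketch)."""
        def __init__(self):
            super(ClipGradients, self).__init__()
            self.hyper_map = C.HyperMap()

        def construct(self, grads):
            # Bind clip type/value, then map the multitype func graph over
            # each gradient tensor in the tuple.
            return self.hyper_map(
                F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)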
@@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
@@ -77,7 +77,7 @@ if cfg.bert_network == 'nezha':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
@@ -98,7 +98,7 @@ if cfg.bert_network == 'large':
     bert_net_cfg = BertConfig(
         batch_size=16,
         seq_length=512,
-        vocab_size=30528,
+        vocab_size=30522,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
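The vocab_size corrections match the configs to the actual vocabulary files: the Chinese BERT vocab.txt has 21128 entries (used by the base and nezha configs) and the original English BERT vocab.txt has 30522 (large). A mismatch makes the embedding table's row count disagree with the tokenizer's id space and can break loading pretrained checkpoints. A quick sanity check one can run (hypothetical path):

    # Compare the vocabulary file against the configured embedding rows.
    with open("vocab.txt", encoding="utf-8") as f:
        n_entries = sum(1 for _ in f)
    assert n_entries == bert_net_cfg.vocab_size, (n_entries, bert_net_cfg.vocab_size)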
@@ -26,3 +26,19 @@ def test_clip_by_norm():
     x = Tensor(np.array([[-2, 0, 0], [0, 3, 4]]).astype(np.float32))
     clip_norm = Tensor(np.array([1]).astype(np.float32))
     clip_by_norm(x, clip_norm)
+
+
+@non_graph_engine
+def test_clip_by_norm_const():
+    class Network(nn.Cell):
+        def __init__(self):
+            super(Network, self).__init__()
+            self.norm_value = Tensor(np.array([1]).astype(np.float32))
+            self.clip = nn.ClipByNorm()
+
+        def construct(self, x):
+            return self.clip(x, self.norm_value)
+
+    net = Network()
+    x = Tensor(np.array([[-2, 0, 0], [0, 3, 4]]).astype(np.float32))
+    output = net(x)
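Worth noting about the new test: clip_norm is stored as a plain Cell attribute, so in graph mode it is a compile-time constant inside construct and the test exercises the new _is_equal_one fast path, while the pre-existing test passes clip_norm as a call argument and still covers the general multiply path. The expected output is easy to hand-check:

    import numpy as np

    x = np.array([[-2, 0, 0], [0, 3, 4]], dtype=np.float32)
    # Global L2 norm is sqrt(4 + 9 + 16) = sqrt(29) > 1, so the result is x
    # scaled down to unit norm.
    print(x / np.sqrt((x ** 2).sum()))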