Add BERT CI script

yoonlee666 2020-07-27 15:08:03 +08:00
parent dfab48d532
commit 1dcf9abf6a
1 changed file with 30 additions and 17 deletions

@@ -28,6 +28,7 @@ import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import context
from mindspore import log as logger
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.nn.optim import Lamb
from mindspore.train.callback import Callback
@@ -100,12 +101,8 @@ def me_de_train_dataset(sink_mode=False):
type_cast_op = C.TypeCast(mstype.int32)
new_repeat_count = repeat_count
if sink_mode:
repeat_count = 30
sink_size = 100
ori_dataaet_size = ds.get_dataset_size()
new_size = sink_size * batch_size
ds.set_dataset_size(new_size)
new_repeat_count = int(repeat_count * ori_dataaet_size // ds.get_dataset_size())
new_repeat_count = 3
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
@@ -129,7 +126,10 @@ def weight_variable(shape):
class BertLearningRate(lr_schedules.LearningRateSchedule):
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(BertLearningRate, self).__init__()
self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
self.warmup_flag = False
if warmup_steps > 0:
self.warmup_flag = True
self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
@@ -138,10 +138,13 @@ class BertLearningRate(lr_schedules.LearningRateSchedule):
self.cast = P.Cast()
def construct(self, global_step):
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
decay_lr = self.decay_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
if self.warmup_flag:
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
else:
lr = decay_lr
return lr
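For context on the guard added above: both tests in this file construct the schedule with warmup_steps=0, so the new warmup_flag branch is skipped and the returned lr is pure polynomial decay. The following is a rough pure-Python restatement of the blend, hedged as my reading of WarmUpLR and PolynomialDecayLR rather than the library code itself (decay_steps=300 is a placeholder):

def bert_lr(global_step, learning_rate=5e-5, end_learning_rate=1e-9,
            warmup_steps=0, decay_steps=300, power=10.0):
    # PolynomialDecayLR, approximately: decay from learning_rate down to end_learning_rate.
    p = min(global_step, decay_steps) / decay_steps
    decay_lr = (learning_rate - end_learning_rate) * (1 - p) ** power + end_learning_rate
    if warmup_steps > 0:
        # WarmUpLR, approximately: linear ramp over the first warmup_steps steps.
        warmup_lr = learning_rate * min(global_step, warmup_steps) / warmup_steps
        is_warmup = float(global_step < warmup_steps)
        return (1.0 - is_warmup) * decay_lr + is_warmup * warmup_lr
    return decay_lr

print([bert_lr(s) for s in (0, 150, 300)])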
@@ -174,6 +177,10 @@ class TimeMonitor(Callback):
self.epoch_mseconds_list.append(epoch_mseconds)
self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bert_percision():
"""test bert percision"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
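The four pytest markers added above are what the CI gate keys on (level-0 priority, both Ascend host architectures, single card). A hypothetical way to run just these cases locally; the test file name below is a placeholder, not taken from this commit:

import pytest
# Select only the level0, single-card cases tagged by the markers above.
pytest.main(["-m", "level0 and env_onecard", "test_bert_train.py"])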
@@ -187,10 +194,11 @@ def test_bert_percision():
power=10.0, warmup_steps=0)
decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower()
decay_params = list(filter(decay_filter, net_with_loss.trainable_params()))
other_params = list(filter(no_decay_filter, net_with_loss.trainable_params()))
decay_params = list(filter(decay_filter, netwithloss.trainable_params()))
other_params = list(filter(no_decay_filter, netwithloss.trainable_params()))
group_params = [{'params': decay_params, 'weight_decay': 0.01},
{'params': other_params}]
{'params': other_params},
{'order_params': netwithloss.trainable_params()}]
optimizer = Lamb(group_params, lr)
scale_window = 3
scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window)
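A minimal sketch of the grouped-parameter pattern above, using a toy nn.Dense in place of BertNetworkWithLoss (the toy net is an assumption; the new 'order_params' entry tells the optimizer to keep parameters in the network's original order even though they were split into decay / no-decay groups):

import mindspore.nn as nn
from mindspore.nn.optim import Lamb

net = nn.Dense(4, 2)  # toy stand-in for BertNetworkWithLoss
decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower()
decay_params = list(filter(decay_filter, net.trainable_params()))
other_params = list(filter(no_decay_filter, net.trainable_params()))
group_params = [{'params': decay_params, 'weight_decay': 0.01},   # weights get decay
                {'params': other_params},                          # bias/LayerNorm do not
                {'order_params': net.trainable_params()}]          # keep original ordering
optimizer = Lamb(group_params, 5e-5)  # same Lamb(params, lr) call shape as the test uses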
@@ -239,6 +247,10 @@ def test_bert_percision():
print("loss scale: {}".format(loss_scale))
assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bert_performance():
"""test bert performance"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
@@ -248,15 +260,16 @@ def test_bert_performance():
config = get_config(version=version, batch_size=batch_size)
netwithloss = BertNetworkWithLoss(config, True)
lr = BertLearningRate(decay_steps=ds.get_dataset_size()*new_repeat_count,
lr = BertLearningRate(decay_steps=sink_size * new_repeat_count,
learning_rate=5e-5, end_learning_rate=1e-9,
power=10.0, warmup_steps=0)
decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower()
decay_params = list(filter(decay_filter, net_with_loss.trainable_params()))
other_params = list(filter(no_decay_filter, net_with_loss.trainable_params()))
decay_params = list(filter(decay_filter, netwithloss.trainable_params()))
other_params = list(filter(no_decay_filter, netwithloss.trainable_params()))
group_params = [{'params': decay_params, 'weight_decay': 0.01},
{'params': other_params}]
{'params': other_params},
{'order_params': netwithloss.trainable_params()}]
optimizer = Lamb(group_params, lr)
scale_window = 3
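For readers unfamiliar with the settings repeated here: DynamicLossScaleManager(2 ** 16, 2, scale_window) starts the loss scale at 65536, shrinks it by the factor 2 when a step overflows, and grows it by 2 after scale_window (= 3) consecutive clean steps, which is the behavior the precision test's loss-scale assertion exercises. A pure-Python restatement of that update rule, hedged as my reading rather than the library implementation:

def update_scale(scale, overflow, good_steps, scale_factor=2, scale_window=3):
    # One dynamic loss-scale update (illustrative only).
    if overflow:
        return max(scale / scale_factor, 1), 0   # shrink and reset the clean-step counter
    good_steps += 1
    if good_steps >= scale_window:
        return scale * scale_factor, 0           # grow after a full clean window
    return scale, good_steps

scale, good = 2 ** 16, 0
for has_overflow in (False, False, False, True, False):
    scale, good = update_scale(scale, has_overflow, good)
print(scale)  # 65536.0 for this particular overflow pattern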
@@ -285,7 +298,7 @@ def test_bert_performance():
else:
logger.info("***************** BERT param name is 3 {}".format(name))
param.default_input = weight_variable(value.asnumpy().shape)
time_monitor_callback = TimeMonitor(ds.get_dataset_size())
time_monitor_callback = TimeMonitor(sink_size)
model.train(new_repeat_count, ds, callbacks=[time_monitor_callback, callback],
dataset_sink_mode=True, sink_size=sink_size)
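A back-of-the-envelope note on the sink-mode call above, hedged: sink_size is assumed to stay at 100 as in the removed block earlier in this diff, and new_repeat_count is the constant 3 introduced by this commit:

sink_size = 100          # assumed steps per sinked "epoch"
new_repeat_count = 3     # epochs passed to model.train in this commit
total_steps = sink_size * new_repeat_count   # 300 optimizer steps in total
# TimeMonitor(sink_size) divides each epoch's wall-clock time by sink_size, so its
# per_step_mseconds_list holds one average step latency per sinked epoch.
print(total_steps)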