forked from mindspore-Ecosystem/mindspore
!1331 delete dropoutgenmask and dropoutdomask when dropout prob equals 0 to enhance performance and adjust ci script
Merge pull request !1331 from yoonlee666/master-deletedropout
This commit is contained in:
commit
889696bcab
|
@ -42,7 +42,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
|
|||
if enable_data_sink == "true":
|
||||
new_size = data_sink_steps * bert_net_cfg.batch_size
|
||||
ds.set_dataset_size(new_size)
|
||||
repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
|
||||
new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
|
||||
|
@ -55,4 +55,4 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
|
|||
ds = ds.repeat(repeat_count)
|
||||
logger.info("data size: {}".format(ds.get_dataset_size()))
|
||||
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
|
||||
return ds
|
||||
return ds, new_repeat_count
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore import context
|
|||
from mindspore.train.model import Model
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
||||
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
|
||||
|
@ -87,8 +87,9 @@ def run_pretrain():
|
|||
rank = 0
|
||||
device_num = 1
|
||||
|
||||
ds = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle, args_opt.enable_data_sink,
|
||||
args_opt.data_sink_steps, args_opt.data_dir, args_opt.schema_dir)
|
||||
ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
|
||||
args_opt.enable_data_sink, args_opt.data_sink_steps,
|
||||
args_opt.data_dir, args_opt.schema_dir)
|
||||
|
||||
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
|
||||
|
||||
|
@ -112,7 +113,7 @@ def run_pretrain():
|
|||
else:
|
||||
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
|
||||
format(cfg.optimizer))
|
||||
callback = [LossCallBack()]
|
||||
callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
|
||||
if args_opt.enable_save_ckpt == "true":
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
|
||||
keep_checkpoint_max=args_opt.save_checkpoint_num)
|
||||
|
@ -133,6 +134,6 @@ def run_pretrain():
|
|||
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
|
||||
|
||||
model = Model(netwithgrads)
|
||||
model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
|
||||
model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
|
||||
if __name__ == '__main__':
|
||||
run_pretrain()
|
||||
|
|
|
@ -99,6 +99,9 @@ class Dropout(Cell):
|
|||
out, _ = self.dropout(x)
|
||||
return out
|
||||
|
||||
if self.keep_prob == 1:
|
||||
return x
|
||||
|
||||
shape = self.get_shape(x)
|
||||
dtype = P.DType()(x)
|
||||
keep_prob = self.cast(self.keep_prob, dtype)
|
||||
|
|
|
@ -26,7 +26,7 @@ from mindspore import context
|
|||
from mindspore import log as logger
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.nn.optim import Lamb
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||
from mindspore.train.model import Model
|
||||
|
@ -73,7 +73,7 @@ def get_config(version='base', batch_size=1):
|
|||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=True,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
|
@ -138,7 +138,9 @@ def test_bert_tdt():
|
|||
batch_size = int(os.getenv('BATCH_SIZE', '16'))
|
||||
config = get_config(version=version, batch_size=batch_size)
|
||||
netwithloss = BertNetworkWithLoss(config, True)
|
||||
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
|
||||
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*ds.get_repeat_count(),
|
||||
start_learning_rate=5e-5, end_learning_rate=1e-9,
|
||||
power=10.0, warmup_steps=0, weight_decay=0.01)
|
||||
scale_window = 3
|
||||
scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window)
|
||||
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
|
||||
|
@ -169,10 +171,10 @@ def test_bert_tdt():
|
|||
|
||||
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
|
||||
loss_value = np.array(callback.loss_list)
|
||||
expect_loss_value = [12.191826, 11.966009, 11.972208, 11.98216, 11.973932, 12.611078, 12.17554, 12.840299,
|
||||
12.403329, 12.621632]
|
||||
expect_loss_value = [12.207201, 11.980862, 11.984737, 11.879344, 11.832838, 12.411388,
|
||||
12.009449, 12.621273, 12.223175, 12.427313]
|
||||
print("loss value: {}".format(loss_value))
|
||||
assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)
|
||||
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
|
||||
|
||||
overflow = np.array(callback.overflow_list)
|
||||
expect_overflow = [True, True, False, False, False, True, False, False, False, True]
|
||||
|
@ -182,7 +184,7 @@ def test_bert_tdt():
|
|||
loss_scale = np.array(callback.lossscale_list)
|
||||
expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0]
|
||||
print("loss scale: {}".format(loss_scale))
|
||||
assert np.allclose(loss_scale, expect_loss_scale, 0.00001, 0.00001)
|
||||
assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue