diff --git a/model_zoo/official/nlp/transformer/train.py b/model_zoo/official/nlp/transformer/train.py index 21049f8644e..9285a6fb3bd 100644 --- a/model_zoo/official/nlp/transformer/train.py +++ b/model_zoo/official/nlp/transformer/train.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,19 +34,23 @@ from mindspore import context from mindspore.common import set_seed from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \ - TransformerTrainOneStepWithLossScaleCell + TransformerTrainOneStepWithLossScaleCell from src.config import cfg, transformer_net_cfg, transformer_net_cfg_gpu from src.dataset import create_transformer_dataset from src.lr_schedule import create_dynamic_lr set_seed(1) + def get_ms_timestamp(): t = time.time() return int(round(t * 1000)) + + time_stamp_init = False time_stamp_first = 0 + class LossCallBack(Callback): """ Monitor the loss in training. @@ -56,6 +60,7 @@ class LossCallBack(Callback): Args: per_print_times (int): Print loss every times. Default: 1. """ + def __init__(self, per_print_times=1, rank_id=0): super(LossCallBack, self).__init__() if not isinstance(per_print_times, int) or per_print_times < 0: @@ -116,6 +121,7 @@ def argparse_init(): return parser + def run_transformer_train(): """ Transformer training. @@ -128,6 +134,9 @@ def run_transformer_train(): context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False) + if args.device_target == "GPU": + # Enable graph kernel + context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion") if args.distribute == "true": if args.device_target == "Ascend": device_num = args.device_num @@ -204,5 +213,6 @@ def run_transformer_train(): model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=False) + if __name__ == '__main__': run_transformer_train()