open graph_kernel_flag in transformer network when platform is gpu

This commit is contained in:
zengzitao 2021-06-07 16:32:41 +08:00
parent e07c12a45e
commit ae3e8ba817
1 changed files with 12 additions and 2 deletions

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -34,19 +34,23 @@ from mindspore import context
from mindspore.common import set_seed
from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
TransformerTrainOneStepWithLossScaleCell
TransformerTrainOneStepWithLossScaleCell
from src.config import cfg, transformer_net_cfg, transformer_net_cfg_gpu
from src.dataset import create_transformer_dataset
from src.lr_schedule import create_dynamic_lr
set_seed(1)
def get_ms_timestamp():
t = time.time()
return int(round(t * 1000))
time_stamp_init = False
time_stamp_first = 0
class LossCallBack(Callback):
"""
Monitor the loss in training.
@ -56,6 +60,7 @@ class LossCallBack(Callback):
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, per_print_times=1, rank_id=0):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
@ -116,6 +121,7 @@ def argparse_init():
return parser
def run_transformer_train():
"""
Transformer training.
@ -128,6 +134,9 @@ def run_transformer_train():
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
if args.device_target == "GPU":
# Enable graph kernel
context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
if args.distribute == "true":
if args.device_target == "Ascend":
device_num = args.device_num
@ -204,5 +213,6 @@ def run_transformer_train():
model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=False)
if __name__ == '__main__':
run_transformer_train()