forked from mindspore-Ecosystem/mindspore
open graph_kernel_flag in transformer network when platform is gpu
parent e07c12a45e
commit ae3e8ba817
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -34,19 +34,23 @@ from mindspore import context
 from mindspore.common import set_seed
 
 from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
     TransformerTrainOneStepWithLossScaleCell
 from src.config import cfg, transformer_net_cfg, transformer_net_cfg_gpu
 from src.dataset import create_transformer_dataset
 from src.lr_schedule import create_dynamic_lr
 
 set_seed(1)
 
 
 def get_ms_timestamp():
     t = time.time()
     return int(round(t * 1000))
 
 
 time_stamp_init = False
 time_stamp_first = 0
 
 
 class LossCallBack(Callback):
     """
     Monitor the loss in training.
@@ -56,6 +60,7 @@ class LossCallBack(Callback):
     Args:
         per_print_times (int): Print loss every times. Default: 1.
     """
 
     def __init__(self, per_print_times=1, rank_id=0):
         super(LossCallBack, self).__init__()
         if not isinstance(per_print_times, int) or per_print_times < 0:
@@ -116,6 +121,7 @@ def argparse_init():
 
     return parser
 
 
 def run_transformer_train():
     """
     Transformer training.
@@ -128,6 +134,9 @@ def run_transformer_train():
     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
     context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
 
+    if args.device_target == "GPU":
+        # Enable graph kernel
+        context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
     if args.distribute == "true":
         if args.device_target == "Ascend":
             device_num = args.device_num
@@ -204,5 +213,6 @@ def run_transformer_train():
 
     model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=False)
 
 
 if __name__ == '__main__':
     run_transformer_train()
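Note on the change above: the core of this commit is the GPU-only branch added to run_transformer_train(). Below is a minimal, standalone sketch of that setup; the configure_backend wrapper and its call are hypothetical names for illustration, while the context.set_context arguments are taken verbatim from the diff.

from mindspore import context

def configure_backend(device_target):
    """Hypothetical helper mirroring the diff: graph kernel fusion is only
    switched on when training runs on a GPU target."""
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
    if device_target == "GPU":
        # Enable graph kernel; the extra flag also lets independent fused
        # subgraphs be compiled/executed with parallel fusion.
        context.set_context(enable_graph_kernel=True,
                            graph_kernel_flags="--enable_parallel_fusion")

configure_backend("GPU")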