fix warpctc bug when open graph_kernel flag

This commit is contained in:
zengzitao 2021-05-28 16:14:49 +08:00
parent ea93cc380a
commit 07752d7aaf
1 changed files with 4 additions and 1 deletions

View File

@ -36,6 +36,8 @@ from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
set_seed(1)
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
@ -88,6 +90,7 @@ def modelarts_pre_process():
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
config.save_checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.save_checkpoint_path)
@moxing_wrapper(pre_process=modelarts_pre_process)
def train():
"""Train function."""
@ -110,7 +113,7 @@ def train():
else:
device_num = 1
rank = 0
enable_graph_kernel = args_opt.platform == 'GPU'
enable_graph_kernel = config.device_target == 'GPU'
context.set_context(enable_graph_kernel=enable_graph_kernel)
max_captcha_digits = config.max_captcha_digits