forked from mindspore-Ecosystem/mindspore
fix warpctc bug when the graph_kernel flag is enabled
parent ea93cc380a
commit 07752d7aaf
@@ -36,6 +36,8 @@ from src.model_utils.config import config
 from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
 
 set_seed(1)
+
+
 def modelarts_pre_process():
     '''modelarts pre process function.'''
     def unzip(zip_file, save_dir):
@@ -88,6 +90,7 @@ def modelarts_pre_process():
         print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
     config.save_checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.save_checkpoint_path)
 
+
 @moxing_wrapper(pre_process=modelarts_pre_process)
 def train():
     """Train function."""
@@ -110,7 +113,7 @@ def train():
     else:
         device_num = 1
         rank = 0
-    enable_graph_kernel = args_opt.platform == 'GPU'
+    enable_graph_kernel = config.device_target == 'GPU'
     context.set_context(enable_graph_kernel=enable_graph_kernel)
 
     max_captcha_digits = config.max_captcha_digits
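
The substantive change is the last hunk: the training script reads its settings from the shared config object in src.model_utils.config, so referencing the old argparse-style args_opt.platform presumably fails once the graph_kernel path is exercised; switching the check to config.device_target is the fix. Below is a minimal sketch of the corrected logic, not the repo's full train.py: it assumes a MindSpore 1.x-style context API, and cfg is a hypothetical SimpleNamespace stand-in for the repo's config object.

# Sketch only: cfg is a hypothetical stand-in for src.model_utils.config.config,
# whose device_target normally comes from the YAML config / command line.
from types import SimpleNamespace

from mindspore import context

cfg = SimpleNamespace(device_target='GPU')

# Enable graph kernel fusion only on the GPU backend, mirroring the fixed line above.
enable_graph_kernel = cfg.device_target == 'GPU'
context.set_context(mode=context.GRAPH_MODE,
                    device_target=cfg.device_target,
                    enable_graph_kernel=enable_graph_kernel)

On non-GPU targets the comparison yields False, so graph kernel fusion simply stays off and the set_context call leaves that flag disabled.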