From 07752d7aaf402f576bccadd4f5841335dbdaac67 Mon Sep 17 00:00:00 2001 From: zengzitao Date: Fri, 28 May 2021 16:14:49 +0800 Subject: [PATCH] fix warpctc bug when open graph_kernel flag --- model_zoo/official/cv/warpctc/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/model_zoo/official/cv/warpctc/train.py b/model_zoo/official/cv/warpctc/train.py index df5a8bed6c9..001e06d0da5 100755 --- a/model_zoo/official/cv/warpctc/train.py +++ b/model_zoo/official/cv/warpctc/train.py @@ -36,6 +36,8 @@ from src.model_utils.config import config from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num set_seed(1) + + def modelarts_pre_process(): '''modelarts pre process function.''' def unzip(zip_file, save_dir): @@ -88,6 +90,7 @@ def modelarts_pre_process(): print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1)) config.save_checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.save_checkpoint_path) + @moxing_wrapper(pre_process=modelarts_pre_process) def train(): """Train function.""" @@ -110,7 +113,7 @@ def train(): else: device_num = 1 rank = 0 - enable_graph_kernel = args_opt.platform == 'GPU' + enable_graph_kernel = config.device_target == 'GPU' context.set_context(enable_graph_kernel=enable_graph_kernel) max_captcha_digits = config.max_captcha_digits