add reset op info for tune

This commit is contained in:
LaiYongqiang 2021-07-08 18:56:07 +08:00
parent 1ac696d044
commit b61868b8bf
2 changed files with 7 additions and 2 deletions

View File

@ -29,6 +29,7 @@ from mindspore import log
from .tbe_common import check_kernel_info, TBEException from .tbe_common import check_kernel_info, TBEException
from .helper import _op_select_format, _check_supported from .helper import _op_select_format, _check_supported
# tune type # tune type
NO_TUNE = "NO_TUNE" NO_TUNE = "NO_TUNE"
GA_TUNE = "GA" GA_TUNE = "GA"
@ -355,7 +356,10 @@ class TbeProcess:
log.error("Auto tune init failed, place check your hardware config or go back to normal compile!") log.error("Auto tune init failed, place check your hardware config or go back to normal compile!")
self.tune_init = False self.tune_init = False
return error_id return error_id
self.__reset_op_info = self.get_reset_op_info()
self.__tuner.tune_init = True self.__tuner.tune_init = True
json_info["reset_op_info"] = self.__reset_op_info
op_json = json.dumps(json_info)
self.__all_tune_tasks.append(task_id) self.__all_tune_tasks.append(task_id)
self.__running_tune_tasks.append(task_id) self.__running_tune_tasks.append(task_id)

View File

@ -18,8 +18,8 @@ import datetime
import json import json
import sys import sys
import traceback import traceback
from tbe.common.rl_bank.bank_manager import set_current_op_name
from te.platform.cce_conf import te_set_version from te.platform.cce_conf import te_set_version
from te_fusion.fusion_manager import set_current_op_name
from te_fusion.fusion_util import fusion_op, dump_fusion_json from te_fusion.fusion_util import fusion_op, dump_fusion_json
from te_fusion.parallel_compilation import init_multi_process_env, get_finished_compilation_task, \ from te_fusion.parallel_compilation import init_multi_process_env, get_finished_compilation_task, \
deinit_multi_process_env, start_ga_multi_process deinit_multi_process_env, start_ga_multi_process
@ -331,13 +331,14 @@ class TbeTuner:
raise ValueError("Json string Errors, key:fusion_op not found.") raise ValueError("Json string Errors, key:fusion_op not found.")
kernel_name = json_info["fusion_op"]["fusion_op_name"] kernel_name = json_info["fusion_op"]["fusion_op_name"]
full_name = json_info["fusion_op"]["full_name"] full_name = json_info["fusion_op"]["full_name"]
reset_op_info = json_info["reset_op_info"]
set_current_op_name(kernel_name) set_current_op_name(kernel_name)
converted_json = fusion_to_fusion(json.dumps(json_info), tune_mode="RL") converted_json = fusion_to_fusion(json.dumps(json_info), tune_mode="RL")
job_type = RL_COMPILE job_type = RL_COMPILE
base_kernel = './kernel_meta/' + kernel_name + '.o' base_kernel = './kernel_meta/' + kernel_name + '.o'
compile_info = None compile_info = None
try: try:
fusion_op(converted_json) fusion_op(converted_json, reset_op_info=reset_op_info)
# pylint: disable=broad-except # pylint: disable=broad-except
except Exception: except Exception:
exc_type, exc_value, _ = sys.exc_info() exc_type, exc_value, _ = sys.exc_info()