forked from mindspore-Ecosystem/mindspore
!19718 add reset op info for tune
Merge pull request !19718 from laiyongqiang/tune_reset_op_info
This commit is contained in:
commit
21316551d9
|
@ -29,6 +29,7 @@ from mindspore import log
|
||||||
from .tbe_common import check_kernel_info, TBEException
|
from .tbe_common import check_kernel_info, TBEException
|
||||||
from .helper import _op_select_format, _check_supported
|
from .helper import _op_select_format, _check_supported
|
||||||
|
|
||||||
|
|
||||||
# tune type
|
# tune type
|
||||||
NO_TUNE = "NO_TUNE"
|
NO_TUNE = "NO_TUNE"
|
||||||
GA_TUNE = "GA"
|
GA_TUNE = "GA"
|
||||||
|
@ -355,7 +356,10 @@ class TbeProcess:
|
||||||
log.error("Auto tune init failed, place check your hardware config or go back to normal compile!")
|
log.error("Auto tune init failed, place check your hardware config or go back to normal compile!")
|
||||||
self.tune_init = False
|
self.tune_init = False
|
||||||
return error_id
|
return error_id
|
||||||
|
self.__reset_op_info = self.get_reset_op_info()
|
||||||
self.__tuner.tune_init = True
|
self.__tuner.tune_init = True
|
||||||
|
json_info["reset_op_info"] = self.__reset_op_info
|
||||||
|
op_json = json.dumps(json_info)
|
||||||
self.__all_tune_tasks.append(task_id)
|
self.__all_tune_tasks.append(task_id)
|
||||||
self.__running_tune_tasks.append(task_id)
|
self.__running_tune_tasks.append(task_id)
|
||||||
|
|
||||||
|
|
|
@ -18,8 +18,8 @@ import datetime
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
from tbe.common.rl_bank.bank_manager import set_current_op_name
|
||||||
from te.platform.cce_conf import te_set_version
|
from te.platform.cce_conf import te_set_version
|
||||||
from te_fusion.fusion_manager import set_current_op_name
|
|
||||||
from te_fusion.fusion_util import fusion_op, dump_fusion_json
|
from te_fusion.fusion_util import fusion_op, dump_fusion_json
|
||||||
from te_fusion.parallel_compilation import init_multi_process_env, get_finished_compilation_task, \
|
from te_fusion.parallel_compilation import init_multi_process_env, get_finished_compilation_task, \
|
||||||
deinit_multi_process_env, start_ga_multi_process
|
deinit_multi_process_env, start_ga_multi_process
|
||||||
|
@ -331,13 +331,14 @@ class TbeTuner:
|
||||||
raise ValueError("Json string Errors, key:fusion_op not found.")
|
raise ValueError("Json string Errors, key:fusion_op not found.")
|
||||||
kernel_name = json_info["fusion_op"]["fusion_op_name"]
|
kernel_name = json_info["fusion_op"]["fusion_op_name"]
|
||||||
full_name = json_info["fusion_op"]["full_name"]
|
full_name = json_info["fusion_op"]["full_name"]
|
||||||
|
reset_op_info = json_info["reset_op_info"]
|
||||||
set_current_op_name(kernel_name)
|
set_current_op_name(kernel_name)
|
||||||
converted_json = fusion_to_fusion(json.dumps(json_info), tune_mode="RL")
|
converted_json = fusion_to_fusion(json.dumps(json_info), tune_mode="RL")
|
||||||
job_type = RL_COMPILE
|
job_type = RL_COMPILE
|
||||||
base_kernel = './kernel_meta/' + kernel_name + '.o'
|
base_kernel = './kernel_meta/' + kernel_name + '.o'
|
||||||
compile_info = None
|
compile_info = None
|
||||||
try:
|
try:
|
||||||
fusion_op(converted_json)
|
fusion_op(converted_json, reset_op_info=reset_op_info)
|
||||||
# pylint: disable=broad-except
|
# pylint: disable=broad-except
|
||||||
except Exception:
|
except Exception:
|
||||||
exc_type, exc_value, _ = sys.exc_info()
|
exc_type, exc_value, _ = sys.exc_info()
|
||||||
|
|
Loading…
Reference in New Issue