forked from mindspore-Ecosystem/mindspore
!13148 auto tune support dynamic shape
From: @liubuyu Reviewed-by: @zhoufeng54,@kisnwang Signed-off-by: @kisnwang
This commit is contained in:
commit
0a0f6064d3
|
@ -19,6 +19,8 @@ import sys
|
|||
from te.platform.cce_conf import te_set_version
|
||||
from te.platform.fusion_util import fusion_op
|
||||
import te
|
||||
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
|
||||
# pylint: disable=wrong-import-position
|
||||
from tbe_common import check_kernel_info, get_args, get_built_in_impl_path
|
||||
|
||||
build_in_impl_path = get_built_in_impl_path()
|
||||
|
@ -50,13 +52,14 @@ def _replace_range(args):
|
|||
range_item[index] = None
|
||||
|
||||
|
||||
def build_op(build_type, json_str):
|
||||
def build_op(build_type, json_str, tune_mode=None):
|
||||
"""
|
||||
call op functions with function name and input args json_str
|
||||
|
||||
Args:
|
||||
build_type : op function name
|
||||
json_str (str): op function input args
|
||||
tune_mode (str): if use auto_tune
|
||||
|
||||
Raises:
|
||||
Exception: If specific keyword is not found.
|
||||
|
@ -93,8 +96,10 @@ def build_op(build_type, json_str):
|
|||
else:
|
||||
if is_dynamic_shape:
|
||||
op_module = __import__("impl.dynamic." + op_name, globals(), locals(), [op_name], 0)
|
||||
op_module_name = "impl.dynamic." + op_name
|
||||
else:
|
||||
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
|
||||
op_module_name = "impl." + op_name
|
||||
# get function
|
||||
if build_type == op_build:
|
||||
if custom_flag:
|
||||
|
@ -111,9 +116,14 @@ def build_op(build_type, json_str):
|
|||
if is_dynamic_shape:
|
||||
with te.op.dynamic():
|
||||
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
if tune_mode is not None:
|
||||
return (te.op.get_compile_info()), (inputs_args, outputs_args, attrs_args), op_module_name
|
||||
return te.op.get_compile_info()
|
||||
else:
|
||||
return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
res = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
if tune_mode is not None:
|
||||
return res, (inputs_args, outputs_args, attrs_args), op_module_name
|
||||
return res
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(e)
|
||||
|
@ -149,7 +159,7 @@ def compile_with_json(json_str):
|
|||
if "fusion_op" in json_info:
|
||||
ret = compile_fusion_op(json_str)
|
||||
else:
|
||||
ret = build_op(op_build, json_str)
|
||||
ret = build_op(op_build, json_str, None)
|
||||
return ret
|
||||
|
||||
|
||||
|
|
|
@ -326,16 +326,16 @@ class TbeProcess:
|
|||
self.__running_tune_tasks.append(task_id)
|
||||
|
||||
if tune_mode == RL_TUNE:
|
||||
ret, job_type = self.__tuner.rl_tune(task_id, op_json)
|
||||
ret, job_type, compile_info = self.__tuner.rl_tune(task_id, op_json)
|
||||
if job_type is RL_OFFLINE or job_type is RL_ONLINE:
|
||||
if not ret:
|
||||
# offline and online hit will return false
|
||||
res = task_id, "Success", "Success"
|
||||
res = task_id, "Success", compile_info
|
||||
self.__finish_tune_task.append(res)
|
||||
self.__running_tune_tasks.remove(task_id)
|
||||
elif job_type is RL_COMPILE:
|
||||
if not ret:
|
||||
res = task_id, "Fail", "Fail"
|
||||
res = task_id, "Fail", compile_info
|
||||
self.__finish_tune_task.append(res)
|
||||
self.__running_tune_tasks.remove(task_id)
|
||||
elif tune_mode == GA_TUNE:
|
||||
|
@ -384,13 +384,14 @@ class TbeProcess:
|
|||
for item in ret:
|
||||
task_id = item['task_id']
|
||||
status_code = item['status_code']
|
||||
compile_info = item["op_res"] if "op_res" in item else "{}"
|
||||
res = None
|
||||
if status_code == 0:
|
||||
res = task_id, "Success", "Success"
|
||||
res = task_id, "Success", compile_info
|
||||
else:
|
||||
self.__failed_tune_task.append(task_id)
|
||||
log.info("task_id:{}, json:{}".format(task_id, self.__task_info[task_id]))
|
||||
res = task_id, "Failed", "Failed"
|
||||
res = task_id, "Failed", compile_info
|
||||
self.__finish_tune_task.append(res)
|
||||
self.__running_tune_tasks.remove(task_id)
|
||||
ret = self.__finish_tune_task.pop()
|
||||
|
|
|
@ -27,13 +27,14 @@ import auto_tune
|
|||
from schedule_search.rl_online_tune import rl_tune_init, dispatch_fusion_tune_task, dispatch_single_tune_task, \
|
||||
rl_tune_deinit
|
||||
from mindspore import log
|
||||
from .tbe_common import get_args
|
||||
from .compiler import build_op
|
||||
from .re_construct_json import single_to_fusion, fusion_to_fusion
|
||||
|
||||
TE_LOG_LEVEL = ["DEBUG", "INFO", "WARNING", "ERROR"]
|
||||
RL_COMPILE = "RL_COMPILE"
|
||||
RL_OFFLINE = "RL_OFFLINE"
|
||||
RL_ONLINE = "RL_ONLINE"
|
||||
OP_BUILD = "compile"
|
||||
|
||||
PLATFORM_FLAG = ["ascend310", "ascend910", "Hi3796CV300ES", "ascend710", "ascend610", "Hi3796CV300CS", "SD3403"]
|
||||
|
||||
|
@ -285,27 +286,20 @@ class TbeTuner:
|
|||
converted_json = single_to_fusion(json.dumps(json_info), tune_mode="RL")
|
||||
op_type = json_info['op_info']['name']
|
||||
kernel_name = json_info['op_info']['kernel_name']
|
||||
op_module = __import__("impl." + op_type, globals(), locals(), [op_type], 0)
|
||||
op_module_name = "impl." + op_type
|
||||
py_fn_name = json_info['op_info']['name']
|
||||
op_func = getattr(op_module, py_fn_name, None)
|
||||
|
||||
tune_mode = "RL"
|
||||
set_current_op_name(kernel_name)
|
||||
inputs_args = get_args(json_info['op_info'], 'inputs')
|
||||
outputs_args = get_args(json_info['op_info'], 'outputs')
|
||||
attrs_args = get_args(json_info['op_info'], 'attrs')
|
||||
op_args = inputs_args, outputs_args, attrs_args
|
||||
# todo build with build_single_op_from_c
|
||||
base_kernel = './kernel_meta/' + kernel_name + '.o'
|
||||
job_type = RL_COMPILE
|
||||
compile_info = "{}"
|
||||
try:
|
||||
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
compile_info, op_args, op_module_name = build_op(OP_BUILD, json.dumps(json_info), tune_mode)
|
||||
# pylint: disable=broad-except
|
||||
except Exception:
|
||||
exc_type, exc_value, _ = sys.exc_info()
|
||||
log.error(
|
||||
"exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc()))
|
||||
return False, job_type
|
||||
return False, job_type, compile_info
|
||||
if self.offline_tune:
|
||||
job_type = RL_OFFLINE
|
||||
dump_fusion_json(converted_json, self.offline_dump_path)
|
||||
|
@ -318,7 +312,7 @@ class TbeTuner:
|
|||
|
||||
self.module_list[op_module_name] = 1
|
||||
self.fusion_need_sync += 1
|
||||
return ret, job_type
|
||||
return ret, job_type, json.dumps(compile_info)
|
||||
|
||||
def get_op_module_names(self, json_info):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue