!13148 auto tune support dynamic shape

From: @liubuyu
Reviewed-by: @zhoufeng54,@kisnwang
Signed-off-by: @kisnwang
This commit is contained in:
mindspore-ci-bot 2021-03-12 14:25:32 +08:00 committed by Gitee
commit 0a0f6064d3
3 changed files with 26 additions and 21 deletions

View File

@ -19,6 +19,8 @@ import sys
from te.platform.cce_conf import te_set_version
from te.platform.fusion_util import fusion_op
import te
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
# pylint: disable=wrong-import-position
from tbe_common import check_kernel_info, get_args, get_built_in_impl_path
build_in_impl_path = get_built_in_impl_path()
@ -50,13 +52,14 @@ def _replace_range(args):
range_item[index] = None
def build_op(build_type, json_str):
def build_op(build_type, json_str, tune_mode=None):
"""
call op functions with function name and input args json_str
Args:
build_type : op function name
json_str (str): op function input args
tune_mode (str): if use auto_tune
Raises:
Exception: If specific keyword is not found.
@ -93,8 +96,10 @@ def build_op(build_type, json_str):
else:
if is_dynamic_shape:
op_module = __import__("impl.dynamic." + op_name, globals(), locals(), [op_name], 0)
op_module_name = "impl.dynamic." + op_name
else:
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
op_module_name = "impl." + op_name
# get function
if build_type == op_build:
if custom_flag:
@ -111,9 +116,14 @@ def build_op(build_type, json_str):
if is_dynamic_shape:
with te.op.dynamic():
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
if tune_mode is not None:
return (te.op.get_compile_info()), (inputs_args, outputs_args, attrs_args), op_module_name
return te.op.get_compile_info()
else:
return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
res = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
if tune_mode is not None:
return res, (inputs_args, outputs_args, attrs_args), op_module_name
return res
except Exception as e:
raise RuntimeError(e)
@ -149,7 +159,7 @@ def compile_with_json(json_str):
if "fusion_op" in json_info:
ret = compile_fusion_op(json_str)
else:
ret = build_op(op_build, json_str)
ret = build_op(op_build, json_str, None)
return ret

View File

@ -326,16 +326,16 @@ class TbeProcess:
self.__running_tune_tasks.append(task_id)
if tune_mode == RL_TUNE:
ret, job_type = self.__tuner.rl_tune(task_id, op_json)
ret, job_type, compile_info = self.__tuner.rl_tune(task_id, op_json)
if job_type is RL_OFFLINE or job_type is RL_ONLINE:
if not ret:
# offline and online hit will return false
res = task_id, "Success", "Success"
res = task_id, "Success", compile_info
self.__finish_tune_task.append(res)
self.__running_tune_tasks.remove(task_id)
elif job_type is RL_COMPILE:
if not ret:
res = task_id, "Fail", "Fail"
res = task_id, "Fail", compile_info
self.__finish_tune_task.append(res)
self.__running_tune_tasks.remove(task_id)
elif tune_mode == GA_TUNE:
@ -384,13 +384,14 @@ class TbeProcess:
for item in ret:
task_id = item['task_id']
status_code = item['status_code']
compile_info = item["op_res"] if "op_res" in item else "{}"
res = None
if status_code == 0:
res = task_id, "Success", "Success"
res = task_id, "Success", compile_info
else:
self.__failed_tune_task.append(task_id)
log.info("task_id:{}, json:{}".format(task_id, self.__task_info[task_id]))
res = task_id, "Failed", "Failed"
res = task_id, "Failed", compile_info
self.__finish_tune_task.append(res)
self.__running_tune_tasks.remove(task_id)
ret = self.__finish_tune_task.pop()

View File

@ -27,13 +27,14 @@ import auto_tune
from schedule_search.rl_online_tune import rl_tune_init, dispatch_fusion_tune_task, dispatch_single_tune_task, \
rl_tune_deinit
from mindspore import log
from .tbe_common import get_args
from .compiler import build_op
from .re_construct_json import single_to_fusion, fusion_to_fusion
TE_LOG_LEVEL = ["DEBUG", "INFO", "WARNING", "ERROR"]
RL_COMPILE = "RL_COMPILE"
RL_OFFLINE = "RL_OFFLINE"
RL_ONLINE = "RL_ONLINE"
OP_BUILD = "compile"
PLATFORM_FLAG = ["ascend310", "ascend910", "Hi3796CV300ES", "ascend710", "ascend610", "Hi3796CV300CS", "SD3403"]
@ -285,27 +286,20 @@ class TbeTuner:
converted_json = single_to_fusion(json.dumps(json_info), tune_mode="RL")
op_type = json_info['op_info']['name']
kernel_name = json_info['op_info']['kernel_name']
op_module = __import__("impl." + op_type, globals(), locals(), [op_type], 0)
op_module_name = "impl." + op_type
py_fn_name = json_info['op_info']['name']
op_func = getattr(op_module, py_fn_name, None)
tune_mode = "RL"
set_current_op_name(kernel_name)
inputs_args = get_args(json_info['op_info'], 'inputs')
outputs_args = get_args(json_info['op_info'], 'outputs')
attrs_args = get_args(json_info['op_info'], 'attrs')
op_args = inputs_args, outputs_args, attrs_args
# todo build with build_single_op_from_c
base_kernel = './kernel_meta/' + kernel_name + '.o'
job_type = RL_COMPILE
compile_info = "{}"
try:
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
compile_info, op_args, op_module_name = build_op(OP_BUILD, json.dumps(json_info), tune_mode)
# pylint: disable=broad-except
except Exception:
exc_type, exc_value, _ = sys.exc_info()
log.error(
"exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc()))
return False, job_type
return False, job_type, compile_info
if self.offline_tune:
job_type = RL_OFFLINE
dump_fusion_json(converted_json, self.offline_dump_path)
@ -318,7 +312,7 @@ class TbeTuner:
self.module_list[op_module_name] = 1
self.fusion_need_sync += 1
return ret, job_type
return ret, job_type, json.dumps(compile_info)
def get_op_module_names(self, json_info):
"""