fix tune failed for DynamicRnn

This commit is contained in:
LaiYongqiang 2021-07-30 10:15:33 +08:00
parent 052183871a
commit 2372a2e054
3 changed files with 14 additions and 4 deletions

View File

@ -32,7 +32,7 @@ from te_fusion.parallel_compilation import init_multi_process_env, start_ga_mult
get_finished_compilation_task
from .tbe_helper import get_soc_info, assemble_op_args, get_compute_op_list, get_options_info, get_fuzz_build_info, \
BuildType, adjust_custom_op_info
BuildType, adjust_custom_op_info, pack_op_args
from .tbe_job import TbeJob, JobStatus
PLATFORM_FLAG = ["Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403"]
@ -516,8 +516,9 @@ def rl_tune_single_op(job: TbeJob):
tune_op_module_name = op_module_name + "@" + py_module_path
base_kernel = job.content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_kernel_name + ".o"
from schedule_search.rl_online_tune import dispatch_single_tune_task
pack_args = pack_op_args(inputs, outputs, attrs)
res = dispatch_single_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, full_name,
tune_op_module_name, op_func_name, op_type, (inputs, outputs, attrs))
tune_op_module_name, op_func_name, op_type, pack_args)
res = _process_rl_tune_result(job, res)
return res

View File

@ -232,3 +232,11 @@ def adjust_custom_op_info(compute_op_info):
module_name, _ = os.path.splitext(file_name)
compute_op_info["py_module_path"] = py_module_path
compute_op_info["module_name"] = module_name
def pack_op_args(inputs, outputs, attrs):
"""
flatten inputs outputs attrs
"""
op_args = (inputs, outputs, attrs)
return [item for arg in op_args for item in arg]

View File

@ -312,9 +312,10 @@ class TbeTuner:
else:
job_type = RL_ONLINE
graph_id = 0
l1size = 0 # todo need to verify
l1size = 0
pack_op_args = [item for arg in op_args for item in arg]
ret = dispatch_single_tune_task(graph_id, task_id, l1size, base_kernel, kernel_name, full_name,
op_module_name + "@" + op_module_name, op_type, op_type, op_args)
op_module_name + "@" + op_module_name, op_type, op_type, pack_op_args)
self.module_list[op_module_name] = 1
self.fusion_need_sync += 1