forked from mindspore-Ecosystem/mindspore
fix tune failed for DynamicRnn
This commit is contained in:
parent
052183871a
commit
2372a2e054
|
@ -32,7 +32,7 @@ from te_fusion.parallel_compilation import init_multi_process_env, start_ga_mult
|
|||
get_finished_compilation_task
|
||||
|
||||
from .tbe_helper import get_soc_info, assemble_op_args, get_compute_op_list, get_options_info, get_fuzz_build_info, \
|
||||
BuildType, adjust_custom_op_info
|
||||
BuildType, adjust_custom_op_info, pack_op_args
|
||||
from .tbe_job import TbeJob, JobStatus
|
||||
|
||||
PLATFORM_FLAG = ["Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403"]
|
||||
|
@ -516,8 +516,9 @@ def rl_tune_single_op(job: TbeJob):
|
|||
tune_op_module_name = op_module_name + "@" + py_module_path
|
||||
base_kernel = job.content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_kernel_name + ".o"
|
||||
from schedule_search.rl_online_tune import dispatch_single_tune_task
|
||||
pack_args = pack_op_args(inputs, outputs, attrs)
|
||||
res = dispatch_single_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, full_name,
|
||||
tune_op_module_name, op_func_name, op_type, (inputs, outputs, attrs))
|
||||
tune_op_module_name, op_func_name, op_type, pack_args)
|
||||
res = _process_rl_tune_result(job, res)
|
||||
return res
|
||||
|
||||
|
|
|
@ -232,3 +232,11 @@ def adjust_custom_op_info(compute_op_info):
|
|||
module_name, _ = os.path.splitext(file_name)
|
||||
compute_op_info["py_module_path"] = py_module_path
|
||||
compute_op_info["module_name"] = module_name
|
||||
|
||||
|
||||
def pack_op_args(inputs, outputs, attrs):
|
||||
"""
|
||||
flatten inputs outputs attrs
|
||||
"""
|
||||
op_args = (inputs, outputs, attrs)
|
||||
return [item for arg in op_args for item in arg]
|
||||
|
|
|
@ -312,9 +312,10 @@ class TbeTuner:
|
|||
else:
|
||||
job_type = RL_ONLINE
|
||||
graph_id = 0
|
||||
l1size = 0 # todo need to verify
|
||||
l1size = 0
|
||||
pack_op_args = [item for arg in op_args for item in arg]
|
||||
ret = dispatch_single_tune_task(graph_id, task_id, l1size, base_kernel, kernel_name, full_name,
|
||||
op_module_name + "@" + op_module_name, op_type, op_type, op_args)
|
||||
op_module_name + "@" + op_module_name, op_type, op_type, pack_op_args)
|
||||
|
||||
self.module_list[op_module_name] = 1
|
||||
self.fusion_need_sync += 1
|
||||
|
|
Loading…
Reference in New Issue