forked from mindspore-Ecosystem/mindspore
new TBE compile server
This commit is contained in:
parent
26cf52e99d
commit
ba8fdcfeae
|
@ -0,0 +1,587 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""tbe adapter to adapt te/topi/auto-tune python api """
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
from tbe.common.rl_bank.bank_manager import set_current_op_name
|
||||
from te.platform.cce_conf import te_set_version
|
||||
from te.platform.cce_policy import set_L1_info
|
||||
from te_fusion.compile_task_manager import dispatch_prebuild_task, dispatch_single_op_compile_task, import_py_module, \
|
||||
dispatch_fusion_op_compile_task, dispatch_autotune_task, sync_op_tune_params
|
||||
from te_fusion.compile_task_manager import sync_syspath
|
||||
from te_fusion.fusion_manager import check_supported, call_op_func, clear_fusion_params, check_op_impl_mode, \
|
||||
save_op_params, build_single_op_from_c, op_params_to_json
|
||||
from te_fusion.fusion_util import dump_fusion_json, fusion_op
|
||||
from te_fusion.parallel_compilation import init_multi_process_env, start_ga_multi_process, deinit_multi_process_env, \
|
||||
get_finished_compilation_task
|
||||
|
||||
from .tbe_helper import get_soc_info, assemble_op_args, get_compute_op_list, get_options_info, get_fuzz_build_info, \
|
||||
BuildType, adjust_custom_op_info
|
||||
from .tbe_job import TbeJob, JobStatus
|
||||
|
||||
PLATFORM_FLAG = ["Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403"]
|
||||
|
||||
|
||||
def _tune_init(job: TbeJob):
|
||||
"""
|
||||
Tune Initialize
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
aoto_tiling_mode = job.content["SocInfo"]["autoTilingMode"]
|
||||
offline_tune = job.content["SocInfo"]["offlineTune"]
|
||||
op_bank_update = job.content["SocInfo"]["op_bank_update"]
|
||||
tune_dump_path = job.content["TuneInfo"]["tune_dump_path"]
|
||||
tune_bank_path = job.content["TuneInfo"]["tune_bank_path"]
|
||||
need_ga = bool("GA" in aoto_tiling_mode)
|
||||
need_rl = bool("RL" in aoto_tiling_mode)
|
||||
if offline_tune:
|
||||
os.environ["ENABLE_TUNE_DUMP"] = "TRUE"
|
||||
if op_bank_update:
|
||||
sync_op_tune_params("tbe.common.tiling.tiling_api", "reset_repository", False, "")
|
||||
if need_ga or need_rl or offline_tune:
|
||||
try:
|
||||
import auto_tune.auto_tune_main as at_atm
|
||||
from schedule_search.rl_online_tune import rl_tune_init # pylint: disable=unused-import
|
||||
if need_ga:
|
||||
res = at_atm.check_soc_version()
|
||||
if not res:
|
||||
job.error("check soc version failed in tune init")
|
||||
job.error("GATune run Failed. Run .o Failed, because soc_version doesn't match the device")
|
||||
return False
|
||||
|
||||
except ImportError:
|
||||
msg = "TBEException", \
|
||||
"No module named `auto_tune` or `schedule_search`. If you want tune your op's performance," \
|
||||
"please configure `auto_tune` or `schedule_search` related environment variables." \
|
||||
"Try to set the following environment variables:" \
|
||||
"export fwk_path=/usr/local/Ascend/fwkacllib" \
|
||||
"export PYTHONPATH=${fwk_path}/python/site-packages:$PYTHONPATH" \
|
||||
"export PYTHONPATH=${fwk_path}/python/site-packages/auto_tune.egg/auto_tune:$PYTHONPATH" \
|
||||
"export PYTHONPATH=${fwk_path}/python/site-packages/schedule_search.egg:$PYTHONPATH"
|
||||
job.error(msg)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
if tune_dump_path:
|
||||
os.environ["TUNE_DUMP_PATH"] = str(tune_dump_path)
|
||||
if tune_bank_path:
|
||||
os.environ["TUNE_BANK_PATH"] = str(tune_bank_path)
|
||||
res = _creating_custom_path(job)
|
||||
return res
|
||||
|
||||
|
||||
def __directory_creation(path, concat_path):
|
||||
"""
|
||||
Create directory
|
||||
"""
|
||||
path = os.path.join(path, concat_path)
|
||||
if not os.path.isdir(path):
|
||||
os.makedirs(path, 0o750)
|
||||
return path
|
||||
|
||||
|
||||
def __creating_default_custom_path(auto_tiling_mode, base_custom_path):
|
||||
"""
|
||||
Create default custom path
|
||||
"""
|
||||
base_custom_path = __directory_creation(base_custom_path, "data")
|
||||
tune_flag = []
|
||||
if "RL" in auto_tiling_mode:
|
||||
tune_flag.append("rl")
|
||||
if "GA" in auto_tiling_mode:
|
||||
tune_flag.append("tiling")
|
||||
|
||||
for tune_path in tune_flag:
|
||||
real_path = __directory_creation(base_custom_path, tune_path)
|
||||
for soc_version in PLATFORM_FLAG:
|
||||
final_path = __directory_creation(real_path, soc_version)
|
||||
final_path = __directory_creation(final_path, "custom")
|
||||
return True
|
||||
|
||||
|
||||
def _creating_custom_path(job):
|
||||
"""
|
||||
Create custom path
|
||||
"""
|
||||
auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"]
|
||||
if "NO_TUNE" in auto_tiling_mode:
|
||||
return True
|
||||
|
||||
base_custom_path = job.content["TuneInfo"]["tune_bank_path"]
|
||||
tune_bank_flag = True
|
||||
if not base_custom_path:
|
||||
import auto_tune
|
||||
base_custom_path = os.path.dirname(os.path.realpath(auto_tune.__file__))
|
||||
base_custom_path = os.path.realpath(os.path.join(base_custom_path, "../../../"))
|
||||
tune_bank_flag = False
|
||||
|
||||
if not os.path.isdir(base_custom_path):
|
||||
job.error("Check whether the tuning path [{}] exists.".format(base_custom_path))
|
||||
return False
|
||||
if not os.access(base_custom_path, os.R_OK | os.W_OK | os.X_OK):
|
||||
job.error("Check whether the permission on the tuning path [{}] is correct.".format(base_custom_path))
|
||||
return False
|
||||
|
||||
if not tune_bank_flag:
|
||||
return __creating_default_custom_path(auto_tiling_mode, base_custom_path)
|
||||
return True
|
||||
|
||||
|
||||
def _parallel_compilation_init(initialize: TbeJob):
|
||||
"""
|
||||
Tbe parallel compilation initialize
|
||||
:param initialize:
|
||||
:return:
|
||||
"""
|
||||
os.environ["TE_PARALLEL_COMPILER"] = str(initialize.content["process_num"])
|
||||
embedding = False
|
||||
soc_info = get_soc_info(initialize.content)
|
||||
auto_tiling_mode = initialize.content["SocInfo"]["autoTilingMode"]
|
||||
offline_tune = initialize.content["SocInfo"]["offlineTune"]
|
||||
global_loglevel = initialize.content["log_level"]
|
||||
enable_event = initialize.content["enable_event"]
|
||||
pid_str = os.getpid()
|
||||
time_str = datetime.now().strftime('%Y%m%d_%H%M%S%f')[:-3]
|
||||
pid_ts = "{}_pid{}".format(time_str, pid_str)
|
||||
ret = init_multi_process_env(embedding, soc_info, auto_tiling_mode, global_loglevel, enable_event, pid_ts)
|
||||
if ret is None:
|
||||
initialize.error("Init multiprocess env failed")
|
||||
return False
|
||||
initialize.info("Init multiprocess env success with {} process".format(ret[0]))
|
||||
if "RL" in auto_tiling_mode or offline_tune:
|
||||
res_queue = ret[1]
|
||||
live_checker = ret[2]
|
||||
termin_event = ret[3]
|
||||
from schedule_search.rl_online_tune import rl_tune_init
|
||||
ret = rl_tune_init(soc_info, res_queue, live_checker, termin_event, global_loglevel, pid_ts)
|
||||
if not ret:
|
||||
initialize.error("RL env init failed!")
|
||||
return False
|
||||
initialize.info("RL Tune init success.")
|
||||
if "GA" in auto_tiling_mode:
|
||||
start_ga_multi_process(auto_tiling_mode)
|
||||
initialize.info("GA Tune init success.")
|
||||
return True
|
||||
|
||||
|
||||
def tbe_initialize(job: TbeJob):
|
||||
"""
|
||||
Tbe Initialize
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
os.environ["CONTEXT_MODELCOMPILING"] = "TRUE"
|
||||
soc_info = get_soc_info(job.content)
|
||||
res = te_set_version(*soc_info)
|
||||
if not res:
|
||||
job.error("Set version failed")
|
||||
res = _tune_init(job)
|
||||
if not res:
|
||||
job.error("Tune init failed")
|
||||
res = _parallel_compilation_init(job)
|
||||
if not res:
|
||||
job.error("Parallel compilation failed")
|
||||
job.result = "Success"
|
||||
return res
|
||||
|
||||
|
||||
def get_auto_tune_support_op_list(job: TbeJob):
|
||||
"""
|
||||
Get GA tune supported op list
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
from auto_tune_main import enable_auto_tune_support
|
||||
auto_tune_op_list = enable_auto_tune_support()
|
||||
job.info("auto tune GA support ops list:{}".format(auto_tune_op_list))
|
||||
return auto_tune_op_list
|
||||
|
||||
|
||||
def _normalize_module_name(module_name, py_module_path):
|
||||
"""
|
||||
Normalize module name
|
||||
:param module_name:
|
||||
:param py_module_path:
|
||||
:return:
|
||||
"""
|
||||
if py_module_path not in sys.path:
|
||||
sys.path.insert(0, py_module_path)
|
||||
sync_syspath(py_module_path)
|
||||
|
||||
|
||||
def check_support(job: TbeJob):
|
||||
"""
|
||||
Check support
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
op_compute_info_list = get_compute_op_list(job.content)
|
||||
if len(op_compute_info_list) != 1:
|
||||
job.error("Invalid op compute num ({}) in check_support".format(len(op_compute_info_list)))
|
||||
return False
|
||||
compute_op_info = op_compute_info_list[0]
|
||||
adjust_custom_op_info(compute_op_info)
|
||||
inputs, outputs, attrs = assemble_op_args(compute_op_info)
|
||||
op_func_name = compute_op_info["func_name"]
|
||||
if op_func_name in ("resize_nearest_neighbor_v2_grad_d", "resize_bilinear_v2_grad"):
|
||||
attrs.pop(-2)
|
||||
op_module_name = compute_op_info["module_name"]
|
||||
py_module_path = compute_op_info["py_module_path"]
|
||||
_normalize_module_name(op_module_name, py_module_path)
|
||||
func_name = "check_supported"
|
||||
fuzz_build = compute_op_info["build_type"] == BuildType.FUZZILY.value
|
||||
res = check_supported(op_module_name, func_name, (inputs, outputs, attrs, fuzz_build))
|
||||
if isinstance(res, tuple):
|
||||
result, reason = res
|
||||
result_str = str(result)
|
||||
if result_str == "True":
|
||||
job.result = "FULLY_SUPPORTED"
|
||||
elif result_str == "False":
|
||||
job.result = "NOT_SUPPORTED"
|
||||
elif result_str == "Unknown":
|
||||
job.result = "PARTIALLY_SUPPORTED"
|
||||
job.info("op module {} check support result is partially supported".format(op_module_name))
|
||||
else:
|
||||
job.result = "NOT_SUPPORTED"
|
||||
job.info("op module {} check support result is {}, not supported".format(op_module_name, result_str))
|
||||
if reason:
|
||||
job.info("Unsupported reason is {}".format(reason))
|
||||
else:
|
||||
job.result = str(res)
|
||||
return True
|
||||
|
||||
|
||||
def select_op_format(job: TbeJob):
|
||||
"""
|
||||
Select op format
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
compute_op_info_list = get_compute_op_list(job.content)
|
||||
if len(compute_op_info_list) != 1:
|
||||
job.error("Invalid op compute num ({}) in check_support".format(len(compute_op_info_list)))
|
||||
return False
|
||||
compute_op_info = compute_op_info_list[0]
|
||||
adjust_custom_op_info(compute_op_info)
|
||||
inputs, outputs, attrs = assemble_op_args(compute_op_info)
|
||||
op_module_name = compute_op_info["module_name"]
|
||||
py_module_path = compute_op_info["py_module_path"]
|
||||
_normalize_module_name(op_module_name, py_module_path)
|
||||
op_func_name = "op_select_format"
|
||||
res = call_op_func(op_module_name, op_func_name, (inputs, outputs, attrs))
|
||||
job.result = str(res)
|
||||
return True
|
||||
|
||||
|
||||
def parallel_pre_compile_op(job: TbeJob):
|
||||
"""
|
||||
Parallel pre compile op
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
compute_op_info_list = get_compute_op_list(job.content)
|
||||
if len(compute_op_info_list) != 1:
|
||||
job.error("Invalid op compute num ({}) in pre compile op".format(len(compute_op_info_list)))
|
||||
return False
|
||||
compute_op_info = compute_op_info_list[0]
|
||||
adjust_custom_op_info(compute_op_info)
|
||||
_pre_build_compute_op_info(compute_op_info, job)
|
||||
return True
|
||||
|
||||
|
||||
def _pre_build_compute_op_info(compute_op, job):
|
||||
"""
|
||||
Prebuild by compute op info
|
||||
:param compute_op:
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
l1_size = job.content["l1_size"]
|
||||
if l1_size != -1:
|
||||
set_L1_info("op_L1_space", -1)
|
||||
inputs, outputs, attrs = assemble_op_args(compute_op)
|
||||
op_module_name = compute_op["module_name"]
|
||||
py_module_path = compute_op["py_module_path"]
|
||||
op_func_name = compute_op["func_name"]
|
||||
op_type = compute_op["type"]
|
||||
op_name = compute_op["op_name"]
|
||||
save_op_params(op_name, "prebuild", (outputs, attrs))
|
||||
l1_size = job.content["l1_size"]
|
||||
set_L1_info("op_L1_space", l1_size)
|
||||
_normalize_module_name(op_module_name, py_module_path)
|
||||
unknown_shape = compute_op["unknown_shape"]
|
||||
int64_mode = compute_op["int64mode"]
|
||||
dynamic_compile_static = compute_op["dynamic_compile_static"]
|
||||
res = check_op_impl_mode(op_module_name, op_func_name, op_type, inputs, outputs, unknown_shape,
|
||||
dynamic_compile_static)
|
||||
op_impl_mode = job.content["SocInfo"]["op_impl_mode"]
|
||||
op_impl_mode_list = job.content["SocInfo"]["op_impl_mode_list"]
|
||||
if not res:
|
||||
if op_impl_mode_list:
|
||||
job.warning("The op {} do NOT support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode))
|
||||
else:
|
||||
job.info("OpType {} support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode))
|
||||
options = get_options_info(job.content)
|
||||
dispatch_prebuild_task(job.source_id, job.id, l1_size, op_module_name, op_type, op_func_name, unknown_shape,
|
||||
(inputs, outputs, attrs, options), int64_mode, dynamic_compile_static, job.rl_tune_switch,
|
||||
job.rl_tune_list, job.pass_list, job.op_tune_switch, job.op_tune_list)
|
||||
|
||||
|
||||
def get_prebuild_output(op_name):
|
||||
""" get prebuild output """
|
||||
params_str = op_params_to_json(op_name)
|
||||
try:
|
||||
res = json.loads(params_str)
|
||||
except ValueError:
|
||||
res = {}
|
||||
return res
|
||||
|
||||
|
||||
def do_fuzz_build_tbe_op(job: TbeJob):
|
||||
"""
|
||||
Fuzzy build op
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
job.result = "NOT_CHANGED"
|
||||
return True
|
||||
|
||||
|
||||
def _dump_fusion_op_info_to_json_file(job: TbeJob):
|
||||
"""
|
||||
Dump fusion op info to json file
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
if not job.sys_para_debug_path or job.sys_para_debug_path == "\0":
|
||||
return
|
||||
dump_fusion_json(json.dumps(job.content), job.sys_para_debug_path)
|
||||
|
||||
|
||||
def build_single_pre_op(job: TbeJob):
|
||||
"""
|
||||
Build single op
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
before_build_process(job)
|
||||
compute_op_info_list = get_compute_op_list(job.content)
|
||||
if len(compute_op_info_list) != 1:
|
||||
job.error("Invalid op compute num ({}) in build single op".format(len(compute_op_info_list)))
|
||||
return False
|
||||
compute_op_info = compute_op_info_list[0]
|
||||
adjust_custom_op_info(compute_op_info)
|
||||
inputs, outputs, attrs = assemble_op_args(compute_op_info)
|
||||
op_type = compute_op_info["type"]
|
||||
l1_size = job.content["l1_size"]
|
||||
op_module_name = compute_op_info["module_name"]
|
||||
op_kernel_name = compute_op_info["op_name"]
|
||||
py_module_path = compute_op_info["py_module_path"]
|
||||
op_func_name = compute_op_info["func_name"]
|
||||
_normalize_module_name(op_module_name, py_module_path)
|
||||
unknown_shape = compute_op_info["unknown_shape"]
|
||||
int64_mode = compute_op_info["int64mode"]
|
||||
dynamic_compile_static = compute_op_info["dynamic_compile_static"]
|
||||
op_pattern = compute_op_info["pattern"]
|
||||
options = get_options_info(job.content)
|
||||
fuzz_build_info = get_fuzz_build_info(job.content)
|
||||
dispatch_single_op_compile_task(job.source_id, job.id, l1_size, op_module_name, op_type, op_func_name,
|
||||
op_kernel_name, unknown_shape, (inputs, outputs, attrs, options), int64_mode,
|
||||
None, None, dynamic_compile_static, op_pattern, json.dumps(fuzz_build_info),
|
||||
job.rl_tune_switch, job.rl_tune_list, job.pass_list, job.op_tune_switch,
|
||||
job.op_tune_list)
|
||||
return True
|
||||
|
||||
|
||||
def before_build_process(job: TbeJob):
|
||||
"""
|
||||
Processing before build
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
l1_size = job.content["l1_size"]
|
||||
set_L1_info("op_L1_space", l1_size)
|
||||
_dump_fusion_op_info_to_json_file(job)
|
||||
offline_tune = job.sys_offline_tune
|
||||
if offline_tune:
|
||||
dump_fusion_json(json.dumps(job.content), job.sys_tune_dump_path)
|
||||
|
||||
|
||||
def sync_fusion_env(fusion_need_sync, module_list):
|
||||
"""
|
||||
Sync fusion env
|
||||
:param fusion_need_sync:
|
||||
:param module_list:
|
||||
:return:
|
||||
"""
|
||||
if fusion_need_sync == 0:
|
||||
return True
|
||||
|
||||
module_using = []
|
||||
for key, value in module_list.items():
|
||||
if value > 0:
|
||||
module_using.append(str(key))
|
||||
module_list[key] = 0
|
||||
|
||||
module_str = ",".join(module_using)
|
||||
import_py_module(module_str)
|
||||
return True
|
||||
|
||||
|
||||
def parallel_compile_fusion_op(job: TbeJob):
|
||||
"""
|
||||
Compile fusion op in parallel compiler
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
l1_size = job.content["l1_size"]
|
||||
options = get_options_info(job.content)
|
||||
op_kernel_name = job.content["fusion_op_name"]
|
||||
dispatch_fusion_op_compile_task(job.source_id, job.id, l1_size, json.dumps(job.content), op_kernel_name, None, None,
|
||||
options, job.rl_tune_switch, job.rl_tune_list, job.pass_list,
|
||||
job.op_tune_switch, job.op_tune_list)
|
||||
return True
|
||||
|
||||
|
||||
def ga_tune(job: TbeJob):
|
||||
"""
|
||||
GA tune
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
l1_size = job.content["l1_size"]
|
||||
op_kernel_name = job.content["fusion_op_name"]
|
||||
dispatch_autotune_task(job.source_id, job.id, l1_size, json.dumps(job.content), [], op_kernel_name)
|
||||
job.status = JobStatus.JOB_RUNNING
|
||||
return True
|
||||
|
||||
|
||||
def rl_tune_single_op(job: TbeJob):
|
||||
"""
|
||||
RL tune single op
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
compute_op_info_list = get_compute_op_list(job.content)
|
||||
if len(compute_op_info_list) != 1:
|
||||
job.error("Invalid op compute num ({}) in rl tune single op".format(len(compute_op_info_list)))
|
||||
return False
|
||||
compute_op_info = compute_op_info_list[0]
|
||||
inputs, outputs, attrs = assemble_op_args(compute_op_info)
|
||||
op_type = compute_op_info["type"]
|
||||
l1_size = job.content["l1_size"]
|
||||
op_module_name = compute_op_info["module_name"]
|
||||
op_kernel_name = compute_op_info["op_name"]
|
||||
full_name = compute_op_info["name"]
|
||||
py_module_path = compute_op_info["py_module_path"]
|
||||
op_func_name = compute_op_info["func_name"]
|
||||
_normalize_module_name(op_module_name, py_module_path)
|
||||
set_current_op_name(op_kernel_name)
|
||||
unknown_shape = compute_op_info["unknown_shape"]
|
||||
int64_mode = compute_op_info["int64mode"]
|
||||
dynamic_compile_static = compute_op_info["dynamic_compile_static"]
|
||||
op_pattern = compute_op_info["pattern"]
|
||||
fuzz_build_info = get_fuzz_build_info(job.content)
|
||||
auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"]
|
||||
device_id = job.content["SocInfo"]["deviceId"]
|
||||
try:
|
||||
build_single_op_from_c(op_module_name, op_func_name, op_type, "build", unknown_shape,
|
||||
(inputs, outputs, attrs), int64_mode, dynamic_compile_static, op_pattern,
|
||||
auto_tiling_mode, device_id, json.dumps(fuzz_build_info))
|
||||
# pylint: disable=broad-except
|
||||
except Exception:
|
||||
job.error(
|
||||
"Single op {} build failed, no need to do rl tune, json string:{}".format(op_kernel_name, job.json_string))
|
||||
exc_type, exc_value, _ = sys.exc_info()
|
||||
job.error(
|
||||
"exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc()))
|
||||
return False
|
||||
tune_op_module_name = op_module_name + "@" + py_module_path
|
||||
base_kernel = job.content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_kernel_name + ".o"
|
||||
from schedule_search.rl_online_tune import dispatch_single_tune_task
|
||||
res = dispatch_single_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, full_name,
|
||||
tune_op_module_name, op_func_name, op_type, (inputs, outputs, attrs))
|
||||
res = _process_rl_tune_result(job, res)
|
||||
return res
|
||||
|
||||
|
||||
def rl_tune_fusion_op(job: TbeJob):
|
||||
"""
|
||||
rl tune fusion op
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
op_kernel_name = job.content["fusion_op_name"]
|
||||
set_current_op_name(op_kernel_name)
|
||||
|
||||
try:
|
||||
fusion_op(json.dumps(job.content))
|
||||
# pylint: disable=broad-except
|
||||
except Exception:
|
||||
job.error(
|
||||
"Fusion op {} build failed, no need to do rl tune, json string:{}".format(op_kernel_name, job.json_string))
|
||||
exc_type, exc_value, _ = sys.exc_info()
|
||||
job.error(
|
||||
"exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc()))
|
||||
return False
|
||||
l1_size = job.content["l1_size"]
|
||||
base_kernel = job.content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_kernel_name + ".o"
|
||||
compute_op_list = get_compute_op_list(job.content)
|
||||
op_module_names_str = ""
|
||||
for op in compute_op_list:
|
||||
op_module_names_str = op_module_names_str + "," + op["module_name"]
|
||||
op_module_names_str = op_module_names_str[1:]
|
||||
from schedule_search.rl_online_tune import dispatch_fusion_tune_task
|
||||
res = dispatch_fusion_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, op_module_names_str,
|
||||
json.dumps(job.content))
|
||||
res = _process_rl_tune_result(job, res)
|
||||
return res
|
||||
|
||||
|
||||
def _process_rl_tune_result(job, res):
|
||||
if not res:
|
||||
res = bool(job.sys_offline_tune or os.getenv("REPEAT_TUNE", "False").lower() != "true")
|
||||
else:
|
||||
job.status = JobStatus.JOB_RUNNING
|
||||
res = True
|
||||
return res
|
||||
|
||||
|
||||
def get_finish_tasks(source_id):
|
||||
"""
|
||||
Get finish task from parallel compilation framework
|
||||
:return task info list
|
||||
"""
|
||||
return get_finished_compilation_task(source_id)
|
||||
|
||||
|
||||
def tbe_finalize(auto_tiling_mode, offline_tune):
|
||||
"""
|
||||
finalize tbe parallel compilation resource
|
||||
:param auto_tiling_mode: RL/GA/RL,GA
|
||||
:param offline_tune: True/False
|
||||
:return: None
|
||||
"""
|
||||
deinit_multi_process_env()
|
||||
if "RL" in auto_tiling_mode or offline_tune:
|
||||
from schedule_search.rl_online_tune import rl_tune_deinit
|
||||
rl_tune_deinit()
|
||||
clear_fusion_params()
|
||||
return True
|
|
@ -0,0 +1,234 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""tbe helper to parse json content"""
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
from .tbe_job import JobType
|
||||
|
||||
|
||||
class BuildType(Enum):
|
||||
""" Build Type """
|
||||
INITIALLY = "initially_build"
|
||||
FUZZILY = "fuzzily_build"
|
||||
ACCURATELY = "accurately"
|
||||
|
||||
|
||||
job_type_list = [job_type.value for _, job_type in JobType.__members__.items()]
|
||||
|
||||
|
||||
def check_job_json(job_info):
|
||||
"""
|
||||
Check tne compilation job json's required element
|
||||
:param job_info:tne compilation job json
|
||||
:return: raise value error if wrong
|
||||
"""
|
||||
if 'source_id' not in job_info:
|
||||
raise ValueError("Json string Errors, key:source_id not found.")
|
||||
if 'job_id' not in job_info:
|
||||
raise ValueError("Json string Errors, key:job_id not found.")
|
||||
if 'job_type' not in job_info or not job_info['job_type']:
|
||||
raise ValueError("Json string Errors, key:job_type not found.")
|
||||
if job_info['job_type'] not in job_type_list:
|
||||
raise ValueError("Invalid job type: {}.".format(job_info['job_type']))
|
||||
if 'job_content' not in job_info:
|
||||
raise ValueError("Json string Errors, key:job_content not found.")
|
||||
|
||||
|
||||
def get_soc_info(initialize_job_info):
|
||||
"""
|
||||
Get soc info from initialize job info
|
||||
:param initialize_job_info:
|
||||
:return: soc info
|
||||
"""
|
||||
soc_param = dict()
|
||||
soc_param["op_impl_mode"] = initialize_job_info["SocInfo"]["op_impl_mode"]
|
||||
soc_param["op_debug_level"] = initialize_job_info["SocInfo"]["op_debug_level"]
|
||||
soc_param["op_impl_mode_list"] = initialize_job_info["SocInfo"]["op_impl_mode_list"]
|
||||
soc_param["op_debug_dir"] = initialize_job_info["SocInfo"]["op_debug_dir"]
|
||||
soc_param["vector_fp_ceiling"] = initialize_job_info["SocInfo"]["vector_fp_ceiling"]
|
||||
soc_param['mdl_bank_path'] = initialize_job_info["SocInfo"]["mdl_bank_path"]
|
||||
soc_param['op_bank_path'] = initialize_job_info["SocInfo"]["op_bank_path"]
|
||||
|
||||
soc_info = list()
|
||||
soc_info.append(initialize_job_info["SocInfo"]["socVersion"])
|
||||
soc_info.append(initialize_job_info["SocInfo"]["coreType"])
|
||||
soc_info.append(initialize_job_info["SocInfo"]["coreNum"])
|
||||
soc_info.append(initialize_job_info["SocInfo"]["l1Fusion"])
|
||||
soc_info.append(initialize_job_info["SocInfo"]["l2Mode"])
|
||||
soc_info.append(initialize_job_info["SocInfo"]["l2Fusion"])
|
||||
soc_info.append(soc_param)
|
||||
|
||||
return soc_info
|
||||
|
||||
|
||||
def check_arg_info(io_info):
|
||||
"""
|
||||
Check parameter Validity.
|
||||
:param io_info:A dict, to be checked.
|
||||
:return: Exception: If specific keyword is not found.
|
||||
"""
|
||||
if 'shape' not in io_info:
|
||||
raise ValueError("Json string Errors, key:shape not found.")
|
||||
if 'ori_shape' not in io_info:
|
||||
raise ValueError("Json string Errors, key:ori_shape not found.")
|
||||
if 'format' not in io_info or not io_info['format']:
|
||||
raise ValueError("Json string Errors, key:format not found.")
|
||||
if 'ori_format' not in io_info or not io_info['ori_format']:
|
||||
raise ValueError("Json string Errors, key:ori_format not found.")
|
||||
if 'dtype' not in io_info or not io_info['dtype']:
|
||||
raise ValueError("Json string Errors, key:dtype not found.")
|
||||
if 'param_type' not in io_info or not io_info['param_type']:
|
||||
raise ValueError("Json string Errors, key:param_type not found.")
|
||||
|
||||
|
||||
def get_input_output_args(io_info):
|
||||
"""
|
||||
Get input/output args from io info
|
||||
:param io_info:
|
||||
:return:input/output args
|
||||
"""
|
||||
args = []
|
||||
if io_info is None:
|
||||
return args
|
||||
for item in io_info:
|
||||
if isinstance(item, dict):
|
||||
arg = get_single_io_arg(item)
|
||||
args.append(arg)
|
||||
elif isinstance(item, list):
|
||||
dyn_arg = []
|
||||
for info in item:
|
||||
arg = get_single_io_arg(info)
|
||||
dyn_arg.append(arg)
|
||||
args.append(tuple(dyn_arg))
|
||||
return args
|
||||
|
||||
|
||||
def get_single_io_arg(info):
|
||||
"""
|
||||
Get single input/output arg from io info
|
||||
:param info:
|
||||
:return:input/output arg
|
||||
"""
|
||||
if 'valid' not in info:
|
||||
raise ValueError("Json string Errors, key:valid not found.")
|
||||
if info['valid']:
|
||||
check_arg_info(info)
|
||||
del info['valid']
|
||||
del info['name']
|
||||
res = info
|
||||
else:
|
||||
res = None
|
||||
return res
|
||||
|
||||
|
||||
def assemble_op_args(compute_op_info):
|
||||
"""
|
||||
Assemble op args
|
||||
:param compute_op_info:
|
||||
:return: op args
|
||||
"""
|
||||
inputs_info = compute_op_info["input_desc"] if "input_desc" in compute_op_info.keys() else None
|
||||
outputs_info = compute_op_info["output_desc"] if "output_desc" in compute_op_info.keys() else None
|
||||
attrs = compute_op_info["attr_desc"] if "attr_desc" in compute_op_info.keys() else []
|
||||
inputs = get_input_output_args(inputs_info)
|
||||
outputs = get_input_output_args(outputs_info)
|
||||
attrs.append(compute_op_info["op_name"])
|
||||
return inputs, outputs, attrs
|
||||
|
||||
|
||||
def get_compute_op_list(job_content):
|
||||
"""
|
||||
Get compute op info list from job content info
|
||||
:param job_content: tbe compilation content info
|
||||
:return: compute op info list
|
||||
"""
|
||||
op_list = job_content["op_list"]
|
||||
op_compute_list = []
|
||||
for op in op_list:
|
||||
if op["type"] != "Data":
|
||||
op_compute_list.append(op)
|
||||
return op_compute_list
|
||||
|
||||
|
||||
def get_options_info(job_content):
|
||||
"""
|
||||
Get options info
|
||||
:param job_content:
|
||||
:return: options
|
||||
"""
|
||||
options = dict()
|
||||
options["socVersion"] = job_content["SocInfo"]["socVersion"]
|
||||
options["coreType"] = job_content["SocInfo"]["coreType"]
|
||||
options["coreNum"] = job_content["SocInfo"]["coreNum"]
|
||||
options["l1Fusion"] = job_content["SocInfo"]["l1Fusion"]
|
||||
options["l2Fusion"] = job_content["SocInfo"]["l2Fusion"]
|
||||
options["l2Mode"] = job_content["SocInfo"]["l2Mode"]
|
||||
options["op_debug_level"] = job_content["SocInfo"]["op_debug_level"]
|
||||
options["op_impl_mode"] = job_content["SocInfo"]["op_impl_mode"]
|
||||
options["op_debug_dir"] = job_content["SocInfo"]["op_debug_dir"]
|
||||
options["op_compiler_cache_dir"] = job_content["SocInfo"]["op_compiler_cache_dir"]
|
||||
options["op_compiler_cache_mode"] = job_content["SocInfo"]["op_compiler_cache_mode"]
|
||||
options["mdl_bank_path"] = job_content["SocInfo"]["op_debug_level"]
|
||||
options["op_bank_path"] = job_content["SocInfo"]["op_bank_path"]
|
||||
options["deviceId"] = job_content["SocInfo"]["deviceId"]
|
||||
options["autoTilingMode"] = job_content["SocInfo"]["autoTilingMode"]
|
||||
options["op_impl_mode_list"] = job_content["SocInfo"]["op_impl_mode_list"]
|
||||
return options
|
||||
|
||||
|
||||
def get_fuzz_build_info(job_content):
|
||||
"""
|
||||
Get fuzz build info from job content info
|
||||
:param job_content: job content info
|
||||
:return: fuzz build info
|
||||
"""
|
||||
op_compute_info = get_compute_op_list(job_content)[0]
|
||||
fuzz_build_info = dict()
|
||||
fuzz_build_info["compile_type"] = "fuzzily_build" if op_compute_info["build_type"] == BuildType.FUZZILY.value \
|
||||
else "accurately_build"
|
||||
fuzz_build_info["miss_support_info"] = op_compute_info["miss_support_info"]
|
||||
fuzz_build_info["max_kernel_id"] = op_compute_info["max_kernel_id"]
|
||||
fuzz_build_info["incremental_link"] = os.path.realpath(
|
||||
job_content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_compute_info["name"] + ".json") if \
|
||||
op_compute_info["build_type"] == BuildType.FUZZILY.value else ""
|
||||
return fuzz_build_info
|
||||
|
||||
|
||||
def get_func_names(job_content):
|
||||
"""
|
||||
Get function names from job content json
|
||||
:param job_content: job content info
|
||||
:return: function names
|
||||
"""
|
||||
func_names = []
|
||||
for op in job_content["op_list"]:
|
||||
if "func_name" in op:
|
||||
func_names.append(op["func_name"])
|
||||
return func_names
|
||||
|
||||
|
||||
def adjust_custom_op_info(compute_op_info):
|
||||
"""
|
||||
adjust custom op info
|
||||
:param compute_op_info:
|
||||
:return:
|
||||
"""
|
||||
py_module_path = compute_op_info["py_module_path"]
|
||||
if os.path.isfile(py_module_path):
|
||||
py_module_path, file_name = os.path.split(py_module_path)
|
||||
module_name, _ = os.path.splitext(file_name)
|
||||
compute_op_info["py_module_path"] = py_module_path
|
||||
compute_op_info["module_name"] = module_name
|
|
@ -0,0 +1,159 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""tbe compile job definition"""
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class JobType(Enum):
|
||||
""" Job Type """
|
||||
INITIALIZE_JOB = 'Initialize'
|
||||
FINALIZE_JOB = 'Finalize'
|
||||
CHECK_JOB = 'CheckSupport'
|
||||
SELECT_JOB = 'SelectFormat'
|
||||
PRECOMPILE_JOB = 'PreCompile'
|
||||
COMPILE_JOB = 'Compile'
|
||||
TUNE_JOB = 'Tune'
|
||||
QUERY_JOB = 'Query'
|
||||
|
||||
|
||||
class LogLevel(Enum):
|
||||
""" Log Level """
|
||||
DEBUG = 0
|
||||
INFO = 1
|
||||
WARNING = 2
|
||||
ERROR = 3
|
||||
|
||||
|
||||
class JobStatus(Enum):
|
||||
""" Job Status """
|
||||
JOB_INITIAL = "INITIAL"
|
||||
JOB_FAILED = "FAILED"
|
||||
JOB_SUCCESS = "SUCCESS"
|
||||
JOB_RUNNING = "RUNNING"
|
||||
|
||||
|
||||
class LogMessage:
|
||||
""" Log message """
|
||||
|
||||
def __init__(self, index, level, info):
|
||||
self.index = index
|
||||
self.level = level
|
||||
self.info = info
|
||||
|
||||
|
||||
def _get_message(msg, args):
|
||||
"""
|
||||
Return the message for this LogRecord.
|
||||
|
||||
Return the message for this LogRecord after merging any user-supplied
|
||||
arguments with the message.
|
||||
"""
|
||||
msg = str(msg)
|
||||
if args:
|
||||
msg = msg % args
|
||||
return str(datetime.datetime.now()) + ": " + msg
|
||||
|
||||
|
||||
class TbeJob:
|
||||
""" Tbe compilation job """
|
||||
|
||||
def __init__(self, source_id, job_id, job_type, content, json_str, sys_info):
|
||||
self.source_id = source_id
|
||||
self.id = job_id
|
||||
self.type = JobType(job_type)
|
||||
self.status = JobStatus.JOB_INITIAL
|
||||
self.content = content
|
||||
self.result = ""
|
||||
self.process_info = []
|
||||
self.json_string = json_str
|
||||
self._sys_logger = sys_info["logger"]
|
||||
self.sys_offline_tune = sys_info["offline_tune"]
|
||||
self.sys_tune_dump_path = sys_info["tune_dump_path"]
|
||||
self.sys_para_debug_path = sys_info["para_debug_path"]
|
||||
# license info
|
||||
self.rl_tune_switch = sys_info["rl_tune_switch"]
|
||||
self.rl_tune_list = sys_info["rl_tune_list"]
|
||||
self.op_tune_switch = sys_info["op_tune_switch"]
|
||||
self.op_tune_list = sys_info["op_tune_list"]
|
||||
self.pass_list = sys_info["pass_list"]
|
||||
|
||||
def debug(self, msg, *args, **kwargs):
|
||||
"""
|
||||
log debug level info
|
||||
:param msg:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
processed_msg = _get_message(msg, args)
|
||||
message = LogMessage(len(self.process_info), LogLevel.DEBUG, processed_msg)
|
||||
self.process_info.append(message)
|
||||
self._sys_logger.debug(msg, *args, **kwargs)
|
||||
|
||||
def info(self, msg, *args, **kwargs):
|
||||
"""
|
||||
log info level info
|
||||
:param msg:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
processed_msg = _get_message(msg, args)
|
||||
message = LogMessage(len(self.process_info), LogLevel.INFO, processed_msg)
|
||||
self.process_info.append(message)
|
||||
self._sys_logger.info(msg, *args, **kwargs)
|
||||
|
||||
def warning(self, msg, *args, **kwargs):
|
||||
"""
|
||||
log warning level info
|
||||
:param msg:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
processed_msg = _get_message(msg, args)
|
||||
message = LogMessage(len(self.process_info), LogLevel.WARNING, processed_msg)
|
||||
self.process_info.append(message)
|
||||
self._sys_logger.warning(msg, *args, **kwargs)
|
||||
|
||||
def error(self, msg, *args, **kwargs):
|
||||
"""
|
||||
log error level info
|
||||
:param msg:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
processed_msg = _get_message(msg, args)
|
||||
message = LogMessage(len(self.process_info), LogLevel.ERROR, processed_msg)
|
||||
self.process_info.append(message)
|
||||
self._sys_logger.error(msg, *args, **kwargs)
|
||||
|
||||
def get_result(self):
|
||||
"""
|
||||
Get tht job process result string
|
||||
:return: job process result string
|
||||
"""
|
||||
result = dict()
|
||||
result["status"] = self.status.value
|
||||
result["source_id"] = self.source_id
|
||||
result["job_id"] = self.id
|
||||
result["job_type"] = self.type.value
|
||||
result["result"] = self.result
|
||||
self.debug("Resp result:{}".format(json.dumps(result)))
|
||||
process_info = []
|
||||
for info in self.process_info:
|
||||
msg = {"index": info.index, "level": info.level.value, "message": info.info}
|
||||
process_info.append(msg)
|
||||
result["process_info"] = process_info
|
||||
return json.dumps(result)
|
|
@ -0,0 +1,487 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""tbe job manager"""
|
||||
|
||||
import json
|
||||
import traceback
|
||||
from enum import Enum
|
||||
|
||||
from .tbe_adapter import tbe_initialize, get_auto_tune_support_op_list, tbe_finalize, check_support, select_op_format, \
|
||||
parallel_pre_compile_op, do_fuzz_build_tbe_op, before_build_process, build_single_pre_op, sync_fusion_env, \
|
||||
parallel_compile_fusion_op, rl_tune_single_op, rl_tune_fusion_op, ga_tune, get_finish_tasks, get_prebuild_output
|
||||
from .tbe_helper import check_job_json, get_compute_op_list, get_func_names
|
||||
from .tbe_job import TbeJob, JobStatus, JobType
|
||||
|
||||
|
||||
class TbeJobManager:
|
||||
""" TBE compiler job manager """
|
||||
|
||||
def __init__(self):
|
||||
self.job_handlers = {
|
||||
JobType.INITIALIZE_JOB: self.initialize_handler,
|
||||
JobType.FINALIZE_JOB: self.finalize_handler,
|
||||
JobType.CHECK_JOB: self.check_support_handler,
|
||||
JobType.SELECT_JOB: self.select_format_handler,
|
||||
JobType.PRECOMPILE_JOB: self.pre_compile_handler,
|
||||
JobType.COMPILE_JOB: self.compile_handler,
|
||||
JobType.TUNE_JOB: self.tune_handler,
|
||||
JobType.QUERY_JOB: self.query_handler
|
||||
}
|
||||
|
||||
self._all_jobs = {}
|
||||
self._finished_jobs = {}
|
||||
self._running_jobs = {}
|
||||
self._raw_finish_jobs = {}
|
||||
self.tbe_initialize = False
|
||||
self.para_debug_path = ""
|
||||
self.auto_tiling_mode = ""
|
||||
self.offline_tune = False
|
||||
self.tune_op_list = []
|
||||
self.tune_dump_path = ""
|
||||
self.tune_bank_path = ""
|
||||
self.auto_tune_op_list = []
|
||||
self.pre_build_ops = {}
|
||||
self.fusion_need_sync = 0
|
||||
self.imported_module = {}
|
||||
# license info
|
||||
self.rl_tune_switch = ""
|
||||
self.rl_tune_list = ""
|
||||
self.op_tune_switch = ""
|
||||
self.op_tune_list = ""
|
||||
self.pass_list = ""
|
||||
|
||||
def __del__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the job manager
|
||||
:return: None
|
||||
"""
|
||||
self._all_jobs = {}
|
||||
self._finished_jobs = {}
|
||||
self._running_jobs = {}
|
||||
self._raw_finish_jobs = {}
|
||||
self.para_debug_path = ""
|
||||
self.auto_tiling_mode = ""
|
||||
self.offline_tune = False
|
||||
self.tune_op_list = []
|
||||
self.tune_dump_path = ""
|
||||
self.tune_bank_path = ""
|
||||
self.auto_tune_op_list = []
|
||||
self.pre_build_ops = []
|
||||
self.fusion_need_sync = 0
|
||||
self.imported_module = {}
|
||||
if self.tbe_initialize:
|
||||
tbe_finalize(self.auto_tiling_mode, self.offline_tune)
|
||||
self.tbe_initialize = False
|
||||
|
||||
def job_handler(self, job_str):
|
||||
"""
|
||||
Tbe job handler
|
||||
:param job_str: tbe compile job string
|
||||
:return: job process result json string
|
||||
"""
|
||||
job = None
|
||||
try:
|
||||
job_json = json.loads(job_str)
|
||||
check_job_json(job_json)
|
||||
job_id = job_json["job_id"]
|
||||
source_id = job_json["source_id"]
|
||||
job_type = job_json["job_type"]
|
||||
sys_info = self._get_job_sys_info()
|
||||
job = TbeJob(source_id, job_id, job_type, job_json["job_content"], job_str, sys_info)
|
||||
job.debug("Req job string: {}".format(job_str))
|
||||
post_job(self._all_jobs, job)
|
||||
if not self.tbe_initialize and job.type != JobType.INITIALIZE_JOB:
|
||||
job.error(
|
||||
"Initialize Job should be processed before job {}, job json string:{}".format(job.type,
|
||||
job.json_string))
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
|
||||
func = self.job_handlers.get(job.type)
|
||||
res = func(job)
|
||||
return res
|
||||
# pylint: disable=broad-except
|
||||
except Exception:
|
||||
sys_info = self._get_job_sys_info()
|
||||
job = TbeJob(-1, -1, "", None, job_str, sys_info) if job is None else job
|
||||
job.status = JobStatus.JOB_FAILED
|
||||
job.result = "Exception during job process"
|
||||
job.error("Process Job Failed")
|
||||
job.error("Job json string:\n{}\n".format(job_str))
|
||||
job.error("Error message:{}".format(traceback.format_exc()))
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
|
||||
|
||||
def initialize_handler(self, job: TbeJob):
|
||||
""" Initialize job handler """
|
||||
self._init_sys_info(job)
|
||||
res = tbe_initialize(job)
|
||||
if not res:
|
||||
job.error("Process Initialize Job failed, job json string:{}".format(job.json_string))
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
|
||||
if "GA" in self.auto_tiling_mode:
|
||||
self.auto_tune_op_list = get_auto_tune_support_op_list(job)
|
||||
self.tbe_initialize = True
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
|
||||
|
||||
def finalize_handler(self, job: TbeJob):
|
||||
""" Finalize job handler """
|
||||
if not self.tbe_initialize:
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
|
||||
res = tbe_finalize(self.auto_tiling_mode, self.offline_tune)
|
||||
if not res:
|
||||
job.error("Process Finalize Job failed, job json string:{}".format(job.json_string))
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
|
||||
|
||||
def check_support_handler(self, job: TbeJob):
|
||||
""" Check Support job handler """
|
||||
res = check_support(job)
|
||||
if not res:
|
||||
job.error("Process CheckSupport Job failed, job json string:{}".format(job.json_string))
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
|
||||
self._update_imported_op_module(job)
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
|
||||
|
||||
def select_format_handler(self, job: TbeJob):
|
||||
""" Select Format job handler """
|
||||
res = select_op_format(job)
|
||||
if not res:
|
||||
job.error("Process SelectFormat Job failed, job json string:{}".format(job.json_string))
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
|
||||
|
||||
def pre_compile_handler(self, job: TbeJob):
|
||||
""" Pre Compile job handler """
|
||||
res = parallel_pre_compile_op(job)
|
||||
if not res:
|
||||
job.error("Process PreCompile Job failed, job json string:{}".format(job.json_string))
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
|
||||
self.pre_build_ops[job.content["fusion_op_name"]] = job
|
||||
return self.add_to_running_jobs(job)
|
||||
|
||||
def compile_handler(self, job: TbeJob):
|
||||
""" Compile job handler """
|
||||
compute_op_list = get_compute_op_list(job.content)
|
||||
if len(compute_op_list) == 1: # pylint: disable=no-else-return
|
||||
res = do_fuzz_build_tbe_op(job)
|
||||
if not res:
|
||||
job.error("Process do fuzz build tbe op failed, job json string:{}".format(job.json_string))
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
|
||||
if job.result == "NOT_CHANGED":
|
||||
job.result = ""
|
||||
before_build_process(job)
|
||||
res = build_single_pre_op(job)
|
||||
if not res:
|
||||
job.error("Process build single pre op failed, job json string:{}".format(job.json_string))
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
|
||||
return self.add_to_running_jobs(job)
|
||||
if job.result == "SUCCESS":
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
|
||||
job.error("Process do fuzz build tbe op failed, job json string:{}".format(job.json_string))
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
|
||||
else:
|
||||
before_build_process(job)
|
||||
if self.fusion_need_sync:
|
||||
sync_fusion_env(self.fusion_need_sync, self.imported_module)
|
||||
self.fusion_need_sync = 0
|
||||
res = parallel_compile_fusion_op(job)
|
||||
if not res:
|
||||
job.error("Parallel_compile_fusion_op Job failed, job json string:{}".format(job.json_string))
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
|
||||
return self.add_to_running_jobs(job)
|
||||
|
||||
def tune_handler(self, job: TbeJob):
|
||||
""" Tune job handler """
|
||||
before_build_process(job)
|
||||
tune_mode = self._select_tune_mode(job)
|
||||
if tune_mode == TuneMode.NO_TUNE:
|
||||
return self.compile_handler(job)
|
||||
compute_op_list = get_compute_op_list(job.content)
|
||||
if len(compute_op_list) == 1:
|
||||
if tune_mode == TuneMode.RL_TUNE:
|
||||
res = rl_tune_single_op(job)
|
||||
else:
|
||||
if self.fusion_need_sync:
|
||||
sync_fusion_env(self.fusion_need_sync, self.imported_module)
|
||||
self.fusion_need_sync = 0
|
||||
res = ga_tune(job)
|
||||
if not res:
|
||||
job.error("ga tune Job failed, job json string:{}".format(job.json_string))
|
||||
return self.compile_handler(job)
|
||||
else:
|
||||
if self.fusion_need_sync:
|
||||
sync_fusion_env(self.fusion_need_sync, self.imported_module)
|
||||
self.fusion_need_sync = 0
|
||||
if tune_mode == TuneMode.RL_TUNE:
|
||||
res = rl_tune_fusion_op(job)
|
||||
else:
|
||||
res = ga_tune(job)
|
||||
if not res:
|
||||
job.error(
|
||||
"Tune Job failed, tune type {}, job json string:{}".format(tune_mode, job.json_string))
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
|
||||
if job.status == JobStatus.JOB_RUNNING:
|
||||
if tune_mode == TuneMode.RL_TUNE and len(compute_op_list) == 1:
|
||||
self._update_imported_op_module(job)
|
||||
return self.add_to_running_jobs(job)
|
||||
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
|
||||
|
||||
def query_handler(self, query_job: TbeJob):
|
||||
""" Query job handler """
|
||||
target_source_id = query_job.content["source_id"]
|
||||
target_job_id = query_job.content["job_id"]
|
||||
target_job = get_job(self._finished_jobs, target_source_id, target_job_id)
|
||||
if target_job:
|
||||
query_job.warning("Query a finished job: {}".format(query_job.content))
|
||||
query_job.result = target_job.get_result()
|
||||
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
|
||||
target_job = get_job(self._raw_finish_jobs, target_source_id, target_job_id)
|
||||
if not target_job:
|
||||
self.update_raw_finished_jobs(query_job)
|
||||
target_job = get_job(self._raw_finish_jobs, target_source_id, target_job_id)
|
||||
if target_job:
|
||||
query_job.debug("Found job in raw finished jobs, source_id:{}, job_id:{}".format(target_source_id,
|
||||
target_job_id))
|
||||
query_job.result = target_job.get_result()
|
||||
del_job(self._raw_finish_jobs, target_job.source_id, target_job.id)
|
||||
self.add_to_finished_jobs(target_job, target_job.status)
|
||||
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
|
||||
target_job = get_job(self._running_jobs, target_source_id, target_job_id)
|
||||
if target_job:
|
||||
query_job.debug("Found job in Running jobs, source_id:{}, job_id:{}".format(target_source_id,
|
||||
target_job_id))
|
||||
target_job.debug("Be Queried")
|
||||
query_job.result = target_job.get_result()
|
||||
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
|
||||
target_job = get_job(self._all_jobs, target_source_id, target_job_id)
|
||||
if target_job:
|
||||
query_job.debug("Found job in all jobs, source_id:{}, job_id:{}".format(target_source_id,
|
||||
target_job_id))
|
||||
target_job.debug("Be Queried")
|
||||
query_job.result = target_job.get_result()
|
||||
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
|
||||
query_job.error("Can't find job in finished/raw_finished/running jobs, source_id: {}".format(target_source_id))
|
||||
query_job.result = ""
|
||||
return self.add_to_finished_jobs(query_job, JobStatus.JOB_FAILED)
|
||||
|
||||
def _get_job_sys_info(self):
|
||||
"""
|
||||
Get job manager system info
|
||||
:return: system info
|
||||
"""
|
||||
sys_info = dict()
|
||||
sys_info["logger"] = DummyLogger
|
||||
sys_info["para_debug_path"] = self.para_debug_path
|
||||
sys_info["tune_dump_path"] = self.tune_dump_path
|
||||
sys_info["offline_tune"] = self.offline_tune
|
||||
# license info
|
||||
sys_info["rl_tune_switch"] = self.rl_tune_switch
|
||||
sys_info["rl_tune_list"] = self.rl_tune_list
|
||||
sys_info["op_tune_switch"] = self.op_tune_switch
|
||||
sys_info["op_tune_list"] = self.op_tune_list
|
||||
sys_info["pass_list"] = self.pass_list
|
||||
return sys_info
|
||||
|
||||
def _init_sys_info(self, initialize_job):
|
||||
"""
|
||||
Initialize job manager system info from INITIALIZE JOB
|
||||
:param initialize_job: initialize job
|
||||
:return: None
|
||||
"""
|
||||
# auto tune info
|
||||
self.auto_tiling_mode = initialize_job.content["SocInfo"]["autoTilingMode"]
|
||||
self.offline_tune = initialize_job.content["SocInfo"]["offlineTune"]
|
||||
self.tune_op_list = initialize_job.content["TuneInfo"]["tune_op_list"]
|
||||
self.tune_dump_path = initialize_job.content["TuneInfo"]["tune_dump_path"]
|
||||
self.tune_bank_path = initialize_job.content["TuneInfo"]["tune_bank_path"]
|
||||
self.para_debug_path = initialize_job.content["para_debug_path"]
|
||||
# license info
|
||||
self.rl_tune_switch = initialize_job.content["LicInfo"]["rl_tune_switch"]
|
||||
self.rl_tune_list = initialize_job.content["LicInfo"]["rl_tune_list"]
|
||||
self.op_tune_switch = initialize_job.content["LicInfo"]["op_tune_switch"]
|
||||
self.op_tune_list = initialize_job.content["LicInfo"]["op_tune_list"]
|
||||
self.pass_list = initialize_job.content["LicInfo"]["pass_list"]
|
||||
|
||||
def _update_imported_op_module(self, job):
|
||||
"""
|
||||
update imported op module info according to new job
|
||||
:param job:
|
||||
:return:
|
||||
"""
|
||||
compute_op_info = get_compute_op_list(job.content)[0]
|
||||
op_module_name = compute_op_info["module_name"]
|
||||
if op_module_name in self.imported_module.keys():
|
||||
self.imported_module[op_module_name] = self.imported_module[op_module_name] + 1
|
||||
else:
|
||||
self.imported_module[op_module_name] = 1
|
||||
self.fusion_need_sync = self.fusion_need_sync + 1
|
||||
|
||||
def _select_tune_mode(self, job):
|
||||
"""
|
||||
Select the corresponding tune mode according to op job content and job manager system info
|
||||
:param job: tbe tune job
|
||||
:return: NO_TUNE RL_TUNE or GA_TUNE
|
||||
"""
|
||||
auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"]
|
||||
offline_tune = job.content["SocInfo"]["offlineTune"]
|
||||
full_name = job.content["full_name"]
|
||||
func_names = get_func_names(job.content)
|
||||
if self.tune_op_list and full_name not in self.tune_op_list:
|
||||
return TuneMode.NO_TUNE
|
||||
if offline_tune:
|
||||
return TuneMode.RL_TUNE
|
||||
if TuneMode.GA_TUNE.value in auto_tiling_mode:
|
||||
for func_name in func_names:
|
||||
if func_name in self.auto_tune_op_list:
|
||||
return TuneMode.GA_TUNE
|
||||
if TuneMode.RL_TUNE.value in auto_tiling_mode:
|
||||
return TuneMode.RL_TUNE
|
||||
return TuneMode.NO_TUNE
|
||||
|
||||
def update_raw_finished_jobs(self, query_job: TbeJob):
|
||||
"""
|
||||
Get new finished jobs from tbe parallel compilation and add them to raw_finished_jobs
|
||||
:param query_job: query job
|
||||
:return: Node
|
||||
"""
|
||||
new_finished_jobs = get_finish_tasks(query_job.source_id)
|
||||
for new_job in new_finished_jobs:
|
||||
source_id = new_job["graph_id"]
|
||||
job_id = new_job["task_id"]
|
||||
target_job = get_job(self._running_jobs, source_id, job_id)
|
||||
if not target_job:
|
||||
query_job.error("Can't get job, source id:{}, job id:{}".format(source_id, job_id))
|
||||
continue
|
||||
target_job.result = new_job["op_res"] if "op_res" in new_job else new_job["result"]
|
||||
if target_job.type == JobType.PRECOMPILE_JOB:
|
||||
op_name = target_job.content["fusion_op_name"]
|
||||
op_params = get_prebuild_output(op_name)
|
||||
pre_compile_result = dict()
|
||||
pre_compile_result["op_pattern"] = target_job.result
|
||||
pre_compile_result["op_params"] = op_params
|
||||
target_job.result = json.dumps(pre_compile_result)
|
||||
target_job.info("Query result:{}".format(new_job["result"]))
|
||||
if new_job["status_code"] == 0:
|
||||
target_job.status = JobStatus.JOB_SUCCESS
|
||||
target_job.info("Query info_msg:{}".format(new_job["info_msg"]))
|
||||
else:
|
||||
target_job.status = JobStatus.JOB_FAILED
|
||||
target_job.error("Query info_msg:{}".format(new_job["info_msg"]))
|
||||
if "err_args" in new_job:
|
||||
target_job.error("Query err_args:{}".format(new_job["err_args"]))
|
||||
if "except_msg" in new_job:
|
||||
target_job.error("Query except_msg:{}".format(new_job["except_msg"]))
|
||||
if "except_tuple_msg" in new_job:
|
||||
target_job.error("Query except_tuple_msg:{}".format(new_job["except_tuple_msg"]))
|
||||
target_job.error("\nOriginal compile json: \n {}\n".format(target_job.json_string))
|
||||
post_job(self._raw_finish_jobs, target_job)
|
||||
del_job(self._running_jobs, target_job.source_id, target_job.id)
|
||||
|
||||
def add_to_finished_jobs(self, job, status):
|
||||
"""
|
||||
add job to finished jobs with job process status
|
||||
:param job:
|
||||
:param status:
|
||||
:return: job process result string
|
||||
"""
|
||||
job.status = status
|
||||
post_job(self._finished_jobs, job)
|
||||
return job.get_result()
|
||||
|
||||
def add_to_running_jobs(self, job):
|
||||
"""
|
||||
add job to running jobs
|
||||
:param job:
|
||||
:return: job process result string
|
||||
"""
|
||||
job.status = JobStatus.JOB_RUNNING
|
||||
post_job(self._running_jobs, job)
|
||||
return job.get_result()
|
||||
|
||||
|
||||
class TuneMode(Enum):
|
||||
NO_TUNE = "NO_TUNE"
|
||||
GA_TUNE = "GA"
|
||||
RL_TUNE = "RL"
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
"""DummyLogger"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def debug(msg, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def info(msg, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def warning(msg, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def error(msg, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def get_job(jobs, source_id, job_id):
|
||||
"""
|
||||
get the job from job list according to source_id and job_id
|
||||
:param jobs: job list
|
||||
:param source_id : target job's source_id
|
||||
:param job_id: target job's job_id
|
||||
:return: job instance if found in job list
|
||||
None if not found in job list
|
||||
"""
|
||||
if source_id not in jobs.keys():
|
||||
return None
|
||||
if job_id not in jobs[source_id].keys():
|
||||
return None
|
||||
return jobs[source_id][job_id]
|
||||
|
||||
|
||||
def post_job(jobs, new_job):
|
||||
"""
|
||||
add the new job into jobs list
|
||||
:param jobs: job list
|
||||
:param new_job : new job
|
||||
:return: None
|
||||
"""
|
||||
if new_job.source_id not in jobs.keys():
|
||||
jobs[new_job.source_id] = dict()
|
||||
jobs[new_job.source_id][new_job.id] = new_job
|
||||
else:
|
||||
jobs[new_job.source_id][new_job.id] = new_job
|
||||
|
||||
|
||||
def del_job(jobs, source_id, job_id):
|
||||
"""
|
||||
delete the job from job list according to source_id and job_id
|
||||
:param jobs: job list
|
||||
:param source_id : target job's source_id
|
||||
:param job_id: target job's job_id
|
||||
:return: bool True or False
|
||||
"""
|
||||
if source_id not in jobs.keys():
|
||||
return False
|
||||
if job_id not in jobs[source_id].keys():
|
||||
return False
|
||||
del jobs[source_id][job_id]
|
||||
return True
|
|
@ -15,9 +15,11 @@
|
|||
"""kernel build server for ascend"""
|
||||
import sys
|
||||
import warnings
|
||||
from mindspore._extends.remote.kernel_build_server import Messager, get_logger, AkgBuilder
|
||||
from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_process, op_select_format
|
||||
|
||||
from mindspore._extends.parallel_compile.tbe_compiler.tbe_job_manager import TbeJobManager
|
||||
from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import check_supported
|
||||
from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_process, op_select_format
|
||||
from mindspore._extends.remote.kernel_build_server import Messager, get_logger, AkgBuilder
|
||||
|
||||
|
||||
class TbeBuilder:
|
||||
|
@ -25,6 +27,7 @@ class TbeBuilder:
|
|||
|
||||
def __init__(self):
|
||||
self.tbe_builder = create_tbe_parallel_process()
|
||||
self.tbe_job_manager = TbeJobManager()
|
||||
|
||||
def init_auto_tune_env(self, mode):
|
||||
return self.tbe_builder.init_auto_tune_env(mode)
|
||||
|
@ -43,6 +46,10 @@ class TbeBuilder:
|
|||
|
||||
def exit(self):
|
||||
self.tbe_builder.exit()
|
||||
self.tbe_job_manager.reset()
|
||||
|
||||
def job_process(self, json):
|
||||
return self.tbe_job_manager.job_handler(json)
|
||||
|
||||
|
||||
class AscendMessager(Messager):
|
||||
|
@ -75,6 +82,11 @@ class AscendMessager(Messager):
|
|||
json = self.get_message()
|
||||
res = self.tbe_builder.start(json)
|
||||
self.send_res(res)
|
||||
elif arg == 'TBE/JOB':
|
||||
self.send_ack()
|
||||
json = self.get_message()
|
||||
res = self.tbe_builder.job_process(json)
|
||||
self.send_res(res)
|
||||
elif arg == 'TBE/WAIT':
|
||||
self.send_ack()
|
||||
task_id, res, pre = self.tbe_builder.wait()
|
||||
|
|
|
@ -114,6 +114,21 @@ int AscendKernelBuildClient::TbeStart(const std::string &json, const std::string
|
|||
return std::stoi(res);
|
||||
}
|
||||
|
||||
std::string AscendKernelBuildClient::TbeSendJob(const std::string &json) {
|
||||
auto res = SendRequest(kTbeJob);
|
||||
if (res != kAck) {
|
||||
MS_LOG(ERROR) << "Send TBE job failed, res: " << res;
|
||||
return "";
|
||||
}
|
||||
// Send the json data.
|
||||
res = SendRequest(json);
|
||||
if (res == kFailed) {
|
||||
MS_LOG(ERROR) << "Send TBE job json failed, res: " << res;
|
||||
return "";
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
bool AscendKernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result) {
|
||||
// Start waiting..
|
||||
auto res = SendRequest(kTbeWait);
|
||||
|
|
|
@ -206,6 +206,7 @@ class AscendKernelBuildClient : public KernelBuildClient {
|
|||
constexpr inline static auto kTbeWait = "TBE/WAIT";
|
||||
constexpr inline static auto kTbeReset = "TBE/RESET";
|
||||
constexpr inline static auto kTbeTune = "TBE/TUNE";
|
||||
constexpr inline static auto kTbeJob = "TBE/JOB";
|
||||
|
||||
// Send server info. query to server
|
||||
constexpr inline static auto kFormat = "FORMAT";
|
||||
|
@ -228,6 +229,7 @@ class AscendKernelBuildClient : public KernelBuildClient {
|
|||
bool CheckSupported(const std::string &json);
|
||||
|
||||
// Run TBE building.
|
||||
std::string TbeSendJob(const std::string &json);
|
||||
int TbeStart(const std::string &json, const std::string &mode);
|
||||
bool TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result);
|
||||
void TbeReset();
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
{
|
||||
"job_content": {
|
||||
"SocInfo": {
|
||||
"autoTilingMode": "NO_TUNE",
|
||||
"coreNum": "",
|
||||
"coreType": "",
|
||||
"deviceId": "1",
|
||||
"l1Fusion": "false",
|
||||
"l2Fusion": "false",
|
||||
"l2Mode": "2",
|
||||
"mdl_bank_path": "",
|
||||
"offlineTune": false,
|
||||
"op_bank_path": "",
|
||||
"op_bank_update": false,
|
||||
"op_compiler_cache_dir": "",
|
||||
"op_compiler_cache_mode": 0,
|
||||
"op_debug_dir": "./",
|
||||
"op_debug_level": "0",
|
||||
"op_impl_mode": "",
|
||||
"op_impl_mode_list": [],
|
||||
"socVersion": "Ascend910A",
|
||||
"vector_fp_ceiling": ""
|
||||
},
|
||||
"TuneInfo": {
|
||||
"tune_bank_path": "",
|
||||
"tune_dump_path": "",
|
||||
"tune_op_list": []
|
||||
},
|
||||
"LicInfo": {
|
||||
"rl_tune_switch": "on",
|
||||
"rl_tune_list": "ALL",
|
||||
"op_tune_switch": "on",
|
||||
"op_tune_list": "ALL",
|
||||
"pass_list": "ALL"
|
||||
},
|
||||
"enable_event": false,
|
||||
"log_level": 1,
|
||||
"para_debug_path": "",
|
||||
"process_num": 8,
|
||||
"tbe_impl_path": "/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe"
|
||||
},
|
||||
"job_id": 1,
|
||||
"job_type": "Initialize",
|
||||
"source_id": 1
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
{}
|
|
@ -0,0 +1,129 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import json
|
||||
import time
|
||||
|
||||
from mindspore._extends.parallel_compile.tbe_compiler.tbe_job_manager import TbeJobManager
|
||||
|
||||
MAX_COMPILE_SECONDS = 400
|
||||
QUERY_INTERVAL = 10
|
||||
|
||||
|
||||
def test_parallel_compilation(compile_job_json_str):
|
||||
with open("Initialize.info", 'r') as init_json_file:
|
||||
# Initialize
|
||||
init_job_json = json.load(init_json_file)
|
||||
tbe_compiler = TbeJobManager()
|
||||
res = tbe_compiler.job_handler(json.dumps(init_job_json))
|
||||
print("Initialize result:" + res)
|
||||
res_json = json.loads(res)
|
||||
for item in res_json["process_info"]:
|
||||
print("### LogLevel:" + str(item["level"]) + " " + item["message"])
|
||||
if res_json["status"] == "FAILED":
|
||||
print("Initialize Failed")
|
||||
return False
|
||||
|
||||
print("\n################# Initialize Success #################\n")
|
||||
# Dispatch Compile Job
|
||||
res = tbe_compiler.job_handler(compile_job_json_str)
|
||||
print("Compile result:" + res)
|
||||
compile_result_json = json.loads(res)
|
||||
source_id = compile_result_json["source_id"]
|
||||
job_id = compile_result_json["job_id"]
|
||||
if compile_result_json["status"] != "RUNNING":
|
||||
# Process Finish Job
|
||||
print("Final Compile Result:{}".format(json.dumps(compile_result_json["result"])))
|
||||
print("Process Logs:")
|
||||
for item in compile_result_json["process_info"]:
|
||||
print("### LogLevel:" + str(item["level"]) + " " + item["message"])
|
||||
res_json = json.loads(res)
|
||||
if res_json["status"] == "FAILED":
|
||||
print("Compile Failed")
|
||||
return False
|
||||
else:
|
||||
# Process Running Job
|
||||
print("Query Running job with max compile seconds {}".format(MAX_COMPILE_SECONDS))
|
||||
job_id = job_id + 1
|
||||
query_job_json = dict()
|
||||
query_job_json["source_id"] = source_id
|
||||
query_job_json["job_id"] = job_id
|
||||
query_job_json["job_type"] = "Query"
|
||||
target_job = dict()
|
||||
target_job["source_id"] = source_id
|
||||
target_job["job_id"] = compile_result_json["job_id"]
|
||||
query_job_json["job_content"] = target_job
|
||||
repeat_time = 0
|
||||
while True:
|
||||
print("Dispatch a Query Job")
|
||||
res = tbe_compiler.job_handler(json.dumps(query_job_json))
|
||||
res_json = json.loads(res)
|
||||
print("Query result:{}".format(res))
|
||||
if res_json["status"] == "SUCCESS":
|
||||
print("Target Job info :{}".format(res_json["result"]))
|
||||
target_job = json.loads(res_json["result"])
|
||||
if target_job["status"] == "RUNNING":
|
||||
job_id = job_id + 1
|
||||
query_job_json["job_id"] = query_job_json["job_id"] + 1
|
||||
for item in res_json["process_info"]:
|
||||
print("### LogLevel:" + str(item["level"]) + " " + item["message"])
|
||||
repeat_time = repeat_time + 1
|
||||
if repeat_time > MAX_COMPILE_SECONDS / QUERY_INTERVAL:
|
||||
print("Query TimeOut")
|
||||
print("\n################# Compile Failed #################\n")
|
||||
break
|
||||
print("Sleep {} seconds".format(QUERY_INTERVAL))
|
||||
time.sleep(QUERY_INTERVAL)
|
||||
else:
|
||||
print("\n $$$Final Compile Result:{}\n".format(json.dumps(target_job["result"])))
|
||||
print("Process Logs:")
|
||||
for item in res_json["process_info"]:
|
||||
print("### LogLevel:" + str(item["level"]) + " " + item["message"])
|
||||
print("Target Job Process Logs:")
|
||||
for item in target_job["process_info"]:
|
||||
print("### LogLevel:" + str(item["level"]) + " " + item["message"])
|
||||
if target_job["status"] == "SUCCESS":
|
||||
print("\n################# Compile Success #################\n")
|
||||
else:
|
||||
print("\n################# Compile Failed #################\n")
|
||||
break
|
||||
else:
|
||||
print("Final Compile Failed:{}".format(res))
|
||||
print("Process Logs:")
|
||||
for item in res_json["process_info"]:
|
||||
print("### LogLevel:" + str(item["level"]) + " " + item["message"])
|
||||
print("\n################# Compile Failed #################\n")
|
||||
break
|
||||
|
||||
# Finalize Job
|
||||
job_id = job_id + 1
|
||||
finalize_job_json = dict()
|
||||
finalize_job_json["source_id"] = source_id
|
||||
finalize_job_json["job_id"] = job_id
|
||||
finalize_job_json["job_type"] = "Finalize"
|
||||
finalize_job_json["job_content"] = dict()
|
||||
res = tbe_compiler.job_handler(json.dumps(finalize_job_json))
|
||||
print("Finalize result:{}".format(res))
|
||||
res_json = json.loads(res)
|
||||
if res_json["status"] == "Failed":
|
||||
print("\n################# Finalize Failed #################\n")
|
||||
return False
|
||||
print("\n################# Finalize Success #################\n")
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open("op.info", "r") as op_json_file:
|
||||
op_json = json.load(op_json_file)
|
||||
test_parallel_compilation(json.dumps(op_json))
|
Loading…
Reference in New Issue