This fixes an issue about mindspore process cannot exit when calling python api op_select_format failed in select kernel steps.
Previously function op_select_format and check_supported raise an exception directly on the tbe_process python side, but we don't deal with the exception, and raise an exeception on c++ side to frontend ME, that will cause some conflict when recycle resource on ME and tbe_process python interpreter. This changes adding try...catch in function op_select_format and check_supported on the python side, and return the Exception string to c++ side, so that we can raise an exception to frontend ME and ME will deal with resouce clearning and exit.
This commit is contained in:
parent
c24252b2cc
commit
5a00d8cb58
|
@ -0,0 +1,114 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""tbe process"""
|
||||
import sys
|
||||
import os
|
||||
from .common import get_args, get_build_in_impl_path, TBEException
|
||||
|
||||
build_in_impl_path = get_build_in_impl_path()
|
||||
|
||||
|
||||
def _op_select_format(kernel_info):
|
||||
"""
|
||||
call op's op_select_format to get op supported format
|
||||
|
||||
Args:
|
||||
kernel_info (dict): kernel info load by json string
|
||||
|
||||
Returns:
|
||||
op supported format
|
||||
"""
|
||||
try:
|
||||
# import module
|
||||
op_name = kernel_info['op_info']['name']
|
||||
impl_path = build_in_impl_path
|
||||
custom_flag = False
|
||||
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
|
||||
op_impl_path = os.path.realpath(kernel_info['impl_path'])
|
||||
if os.path.isfile(op_impl_path):
|
||||
path, file_name = os.path.split(op_impl_path)
|
||||
op_name, _ = os.path.splitext(file_name)
|
||||
impl_path = path
|
||||
custom_flag = True
|
||||
if impl_path not in sys.path:
|
||||
sys.path.insert(0, impl_path)
|
||||
|
||||
if custom_flag:
|
||||
op_module = __import__(op_name)
|
||||
else:
|
||||
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
|
||||
# get function
|
||||
if not hasattr(op_module, "op_select_format"):
|
||||
return ""
|
||||
op_func = getattr(op_module, "op_select_format", None)
|
||||
|
||||
# call function
|
||||
inputs_args = get_args(kernel_info['op_info'], 'inputs')
|
||||
outputs_args = get_args(kernel_info['op_info'], 'outputs')
|
||||
attrs_args = get_args(kernel_info['op_info'], 'attrs')
|
||||
kernel_name = kernel_info['op_info']['kernel_name']
|
||||
ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
|
||||
except Exception as e:
|
||||
raise TBEException(str(e))
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def _check_supported(kernel_info):
|
||||
"""
|
||||
call op's check_supported to check supported or not
|
||||
|
||||
Args:
|
||||
kernel_info (dict): kernel info load by json string
|
||||
|
||||
Returns:
|
||||
bool: check result, true or false
|
||||
"""
|
||||
try:
|
||||
# import module
|
||||
op_name = kernel_info['op_info']['name']
|
||||
impl_path = build_in_impl_path
|
||||
custom_flag = False
|
||||
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
|
||||
op_impl_path = os.path.realpath(kernel_info['impl_path'])
|
||||
if os.path.isfile(op_impl_path):
|
||||
path, file_name = os.path.split(op_impl_path)
|
||||
op_name, _ = os.path.splitext(file_name)
|
||||
impl_path = path
|
||||
custom_flag = True
|
||||
if impl_path not in sys.path:
|
||||
sys.path.insert(0, impl_path)
|
||||
|
||||
if custom_flag:
|
||||
op_module = __import__(op_name)
|
||||
else:
|
||||
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
|
||||
# get function
|
||||
if not hasattr(op_module, "check_supported"):
|
||||
return ""
|
||||
op_func = getattr(op_module, "check_supported", None)
|
||||
|
||||
# call function
|
||||
inputs_args = get_args(kernel_info['op_info'], 'inputs')
|
||||
outputs_args = get_args(kernel_info['op_info'], 'outputs')
|
||||
attrs_args = get_args(kernel_info['op_info'], 'attrs')
|
||||
kernel_name = kernel_info['op_info']['kernel_name']
|
||||
ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
|
||||
except Exception as e:
|
||||
raise TBEException(str(e))
|
||||
|
||||
return ret
|
|
@ -19,10 +19,8 @@ import subprocess
|
|||
import sys
|
||||
import os
|
||||
import json
|
||||
from .common import check_kernel_info, get_args, get_build_in_impl_path
|
||||
|
||||
build_in_impl_path = get_build_in_impl_path()
|
||||
|
||||
from .common import check_kernel_info, TBEException
|
||||
from .helper import _op_select_format, _check_supported
|
||||
|
||||
def create_tbe_parallel_compiler():
|
||||
"""
|
||||
|
@ -41,40 +39,17 @@ def op_select_format(op_json: str):
|
|||
op_json (str): json string of the op
|
||||
|
||||
Returns:
|
||||
op supported format
|
||||
op supported format or exception message
|
||||
"""
|
||||
ret = ""
|
||||
kernel_info = json.loads(op_json)
|
||||
check_kernel_info(kernel_info)
|
||||
try:
|
||||
kernel_info = json.loads(op_json)
|
||||
check_kernel_info(kernel_info)
|
||||
ret = _op_select_format(kernel_info)
|
||||
|
||||
# import module
|
||||
op_name = kernel_info['op_info']['name']
|
||||
impl_path = build_in_impl_path
|
||||
custom_flag = False
|
||||
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
|
||||
op_impl_path = os.path.realpath(kernel_info['impl_path'])
|
||||
if os.path.isfile(op_impl_path):
|
||||
path, file_name = os.path.split(op_impl_path)
|
||||
op_name, _ = os.path.splitext(file_name)
|
||||
impl_path = path
|
||||
custom_flag = True
|
||||
sys.path.insert(0, impl_path)
|
||||
except TBEException as e:
|
||||
return "TBEException: " + str(e)
|
||||
|
||||
if custom_flag:
|
||||
op_module = __import__(op_name)
|
||||
else:
|
||||
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
|
||||
# get function
|
||||
if not hasattr(op_module, "op_select_format"):
|
||||
return ""
|
||||
op_func = getattr(op_module, "op_select_format", None)
|
||||
|
||||
# call function
|
||||
inputs_args = get_args(kernel_info['op_info'], 'inputs')
|
||||
outputs_args = get_args(kernel_info['op_info'], 'outputs')
|
||||
attrs_args = get_args(kernel_info['op_info'], 'attrs')
|
||||
kernel_name = kernel_info['op_info']['kernel_name']
|
||||
ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
return ret
|
||||
|
||||
|
||||
|
@ -86,40 +61,18 @@ def check_supported(op_json: str):
|
|||
op_json (str): json string of the op
|
||||
|
||||
Returns:
|
||||
true or false
|
||||
bool: check result, true or false
|
||||
str: exception message when catch an Exception
|
||||
"""
|
||||
ret = ""
|
||||
kernel_info = json.loads(op_json)
|
||||
check_kernel_info(kernel_info)
|
||||
try:
|
||||
kernel_info = json.loads(op_json)
|
||||
check_kernel_info(kernel_info)
|
||||
ret = _check_supported(kernel_info)
|
||||
|
||||
# import module
|
||||
op_name = kernel_info['op_info']['name']
|
||||
impl_path = build_in_impl_path
|
||||
custom_flag = False
|
||||
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
|
||||
op_impl_path = os.path.realpath(kernel_info['impl_path'])
|
||||
if os.path.isfile(op_impl_path):
|
||||
path, file_name = os.path.split(op_impl_path)
|
||||
op_name, _ = os.path.splitext(file_name)
|
||||
impl_path = path
|
||||
custom_flag = True
|
||||
sys.path.insert(0, impl_path)
|
||||
except TBEException as e:
|
||||
return "TBEException: " + str(e)
|
||||
|
||||
if custom_flag:
|
||||
op_module = __import__(op_name)
|
||||
else:
|
||||
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
|
||||
# get function
|
||||
if not hasattr(op_module, "check_supported"):
|
||||
return ""
|
||||
op_func = getattr(op_module, "check_supported", None)
|
||||
|
||||
# call function
|
||||
inputs_args = get_args(kernel_info['op_info'], 'inputs')
|
||||
outputs_args = get_args(kernel_info['op_info'], 'outputs')
|
||||
attrs_args = get_args(kernel_info['op_info'], 'attrs')
|
||||
kernel_name = kernel_info['op_info']['kernel_name']
|
||||
ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
return ret
|
||||
|
||||
|
||||
|
@ -149,12 +102,12 @@ class CompilerPool:
|
|||
"""compiler pool"""
|
||||
|
||||
def __init__(self):
|
||||
processes = multiprocessing.cpu_count()
|
||||
self.__processe_num = multiprocessing.cpu_count()
|
||||
# max_processes_num: Set the maximum number of concurrent processes for compiler
|
||||
max_processes_num = 16
|
||||
if processes > max_processes_num:
|
||||
processes = max_processes_num
|
||||
self.__pool = multiprocessing.Pool(processes=processes)
|
||||
if self.__processe_num > max_processes_num:
|
||||
self.__processe_num = max_processes_num
|
||||
self.__pool = None
|
||||
self.__next_task_id = 1
|
||||
self.__running_tasks = []
|
||||
|
||||
|
@ -165,11 +118,10 @@ class CompilerPool:
|
|||
del self.__pool
|
||||
|
||||
def exit(self):
|
||||
return
|
||||
# self.__pool.terminate()
|
||||
# self.__pool.join()
|
||||
# if self.__pool is not None:
|
||||
# del self.__pool
|
||||
if self.__pool is not None:
|
||||
self.__pool.terminate()
|
||||
self.__pool.join()
|
||||
del self.__pool
|
||||
|
||||
def start_compile_op(self, op_json):
|
||||
"""
|
||||
|
@ -183,6 +135,8 @@ class CompilerPool:
|
|||
"""
|
||||
task_id = self.__next_task_id
|
||||
self.__next_task_id = self.__next_task_id + 1
|
||||
if self.__pool is None:
|
||||
self.__pool = multiprocessing.Pool(processes=self.__processe_num)
|
||||
task_future = self.__pool.apply_async(func=run_compiler, args=(op_json,))
|
||||
self.__running_tasks.append((task_id, task_future))
|
||||
return task_id
|
||||
|
|
|
@ -98,7 +98,7 @@ void TbeAdapter::NormalizeFuncName(std::string *func_name) {
|
|||
*func_name = name_tmp;
|
||||
auto iter = tbe_func_adapter_map.find(*func_name);
|
||||
if (iter != tbe_func_adapter_map.end()) {
|
||||
MS_LOG(INFO) << "map actual op fron me " << func_name << "to tbe op" << iter->second;
|
||||
MS_LOG(INFO) << "map actual op from me " << func_name << "to tbe op" << iter->second;
|
||||
*func_name = iter->second;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,6 +35,8 @@ namespace kernel {
|
|||
constexpr auto kName = "name";
|
||||
constexpr auto kDtype = "dtype";
|
||||
constexpr auto kFormat = "format";
|
||||
constexpr auto kPrefixInput = "input";
|
||||
constexpr auto kPrefixOutput = "output";
|
||||
const std::map<std::string, std::string> DYNAMIC_FORMAT_MAP = {{"NCHW", "DefaultFormat"},
|
||||
{"NHWC", "DefaultFormat"},
|
||||
{"ND", "DefaultFormat"},
|
||||
|
@ -146,13 +148,13 @@ bool ParseDynamicFormatJson(const std::string &jsonStr, std::vector<std::shared_
|
|||
if (!CheckJsonItemValidity(json_obj, key_name, keys)) {
|
||||
return false;
|
||||
}
|
||||
if (key_name.find("input", 0) != std::string::npos) {
|
||||
if (key_name.compare(0, strlen(kPrefixInput), kPrefixInput) == 0) {
|
||||
std::shared_ptr<OpIOInfo> input = std::make_shared<OpIOInfo>();
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
input->set_name(json_obj[key_name].at(kName));
|
||||
ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), input);
|
||||
inputs->emplace_back(input);
|
||||
} else if (key_name.find("output", 0) != std::string::npos) {
|
||||
} else if (key_name.compare(0, strlen(kPrefixOutput), kPrefixOutput) == 0) {
|
||||
std::shared_ptr<OpIOInfo> output = std::make_shared<OpIOInfo>();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
output->set_name(json_obj[key_name].at(kName));
|
||||
|
|
|
@ -26,6 +26,7 @@ constexpr auto kTbeProcessModule = "mindspore._extends.parallel_compile.tbe_comp
|
|||
constexpr auto kCreateTbeParallelCompilerFunc = "create_tbe_parallel_compiler";
|
||||
constexpr auto kOpSelectFormatFunc = "op_select_format";
|
||||
constexpr auto kCheckSupportedFunc = "check_supported";
|
||||
constexpr auto kTBEException = "TBEException";
|
||||
|
||||
PyObject *TbePythonFuncs::pCreateTbeParallelCompilerFunc_ = nullptr;
|
||||
PyObject *TbePythonFuncs::pTbeCompiler_ = nullptr;
|
||||
|
@ -133,6 +134,10 @@ std::string TbePythonFuncs::OpSelectFormat(const nlohmann::json &kernel_json) {
|
|||
char *pstr = nullptr;
|
||||
(void)PyArg_Parse(pRet, "s", &pstr);
|
||||
res_json_str = pstr;
|
||||
if (res_json_str.compare(0, strlen(kTBEException), kTBEException) == 0) {
|
||||
MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc << "], " << res_json_str
|
||||
<< " ,function args:" << PyObjectToStr(pArg);
|
||||
}
|
||||
return res_json_str;
|
||||
}
|
||||
|
||||
|
@ -167,7 +172,18 @@ bool TbePythonFuncs::CheckSupported(const nlohmann::json &kernel_json) {
|
|||
MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc
|
||||
<< "], function args: " << PyObjectToStr(pArg);
|
||||
}
|
||||
ret = PyObject_IsTrue(pRes) != 0;
|
||||
if (PyBool_Check(pRes)) {
|
||||
ret = PyObject_IsTrue(pRes) != 0;
|
||||
} else {
|
||||
char *pstr = nullptr;
|
||||
(void)PyArg_Parse(pRes, "s", &pstr);
|
||||
std::string res_str = pstr;
|
||||
if (res_str.compare(0, strlen(kTBEException), kTBEException) == 0) {
|
||||
MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc << "], " << res_str
|
||||
<< ", function args: " << PyObjectToStr(pArg);
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue