From dec3b5490418d14b58dfa1672bcef513792bfa16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=B8=85=E5=8D=8E?= Date: Wed, 22 Feb 2023 15:26:43 +0800 Subject: [PATCH] [JIT Fallback] Support JIT Fallback Resolve routine; Refactoring exception re-throw routine; Throw exception if there's nested JIT execution. --- .jenkins/check/config/whitelizard.txt | 3 +- .../optimizer/irpass/branch_culling.h | 2 +- .../ccsrc/include/common/utils/python_utils.h | 32 ++ mindspore/ccsrc/pipeline/jit/fallback.cc | 3 +- mindspore/ccsrc/pipeline/jit/parse/parse.cc | 10 +- mindspore/ccsrc/pipeline/jit/parse/resolve.cc | 21 +- mindspore/ccsrc/pipeline/jit/pipeline.cc | 207 +++++------- mindspore/ccsrc/pipeline/jit/pipeline.h | 28 +- .../pipeline/jit/static_analysis/prim.cc | 18 +- .../jit/static_analysis/static_analysis.cc | 8 +- .../pipeline/pynative/pynative_execute.cc | 49 +-- mindspore/ccsrc/pybind_api/ir/py_execute_py.h | 105 +----- mindspore/ccsrc/utils/python_utils.cc | 307 ++++++++++++++++++ mindspore/ccsrc/utils/utils.cc | 2 + .../python/mindspore/_extends/parse/parser.py | 2 +- .../fallback/test_graph_fallback_runtime.py | 64 ++++ 16 files changed, 579 insertions(+), 282 deletions(-) create mode 100644 mindspore/ccsrc/include/common/utils/python_utils.h create mode 100644 mindspore/ccsrc/utils/python_utils.cc diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index f425277da3f..1ef6537bc0e 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -18,10 +18,9 @@ mindspore/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc:mindspo mindspore/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc:mindspore::device::gpu::GPUKernelRuntime::LaunchKernelDynamic mindspore/mindspore/ccsrc/pipeline/jit/init.cc:PYBIND11_MODULE mindspore/mindspore/ccsrc/pipeline/jit/parse/resolve.cc:mindspore::parse::ResolveObjectToNode -mindspore/mindspore/ccsrc/pipeline/jit/pipeline.cc:mindspore::pipeline::GraphExecutorPy::Compile +mindspore/mindspore/ccsrc/utils/python_utils.cc:mindspore::HandleExceptionRethrow mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc:mindspore::abstract::ConvertAbstractToPython mindspore/mindspore/ccsrc/pybind_api/ir/log_adapter_py.h:mindspore::PyExceptionInitializer::HandleExceptionPy -mindspore/mindspore/ccsrc/pybind_api/ir/py_execute_py.h:mindspore::PyExecuteInitializer::CppInferShapeAndTypePy mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/unary_op_gpu_kernel.h:mindspore::kernel::UnaryOpGpuKernel::Launch mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_rnn_grad_fission_v2.cc:mindspore::opt::AddLSTMInputGradNode mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/drop_out_gen_mask_kernels.cc:aicpu::ARMDropOutGenMaskKernel diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h index 1a325b31bc8..7f7e75c063d 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h @@ -58,7 +58,7 @@ class SwitchSimplify : public OptimizerCaller { << " not support this condition value: " << value_ptr->ToString(); } - MS_LOG(DEBUG) << "condition value: " << value_ptr->ToString() << " bool:" << cond_value; + MS_LOG(DEBUG) << "condition value: " << value_ptr->ToString() << ", cond: " << cond_value; if (cond_value) { return true_br.GetNode(node); } diff --git a/mindspore/ccsrc/include/common/utils/python_utils.h b/mindspore/ccsrc/include/common/utils/python_utils.h new file mode 100644 index 00000000000..b655b870a35 --- /dev/null +++ b/mindspore/ccsrc/include/common/utils/python_utils.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_UTILS_H_ +#define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_UTILS_H_ + +#include + +#include "include/common/visible.h" +#include "utils/trace_base.h" + +namespace mindspore { +COMMON_EXPORT void HandleExceptionRethrow(const std::function &main_func, + const std::function &already_set_error_handler, + const std::function &other_error_handler, + const std::function &default_error_handler, + const DebugInfoPtr &debug_info = nullptr); +} // namespace mindspore +#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_UTILS_H_ diff --git a/mindspore/ccsrc/pipeline/jit/fallback.cc b/mindspore/ccsrc/pipeline/jit/fallback.cc index 061e1a84e06..0d2c36edc52 100644 --- a/mindspore/ccsrc/pipeline/jit/fallback.cc +++ b/mindspore/ccsrc/pipeline/jit/fallback.cc @@ -58,7 +58,8 @@ AnfNodePtr ConvertInterpretedObjectToPyExecute(const FuncGraphPtr &fg, const Val NewValueNode(std::make_shared(keys)), NewValueNode(std::make_shared(values))}); constexpr auto debug_recursive_level = 2; - MS_LOG(DEBUG) << "interpreted_cnode: " << interpreted_cnode->DebugString(debug_recursive_level); + MS_LOG(DEBUG) << "original node: " << node->DebugString(debug_recursive_level) + << ", interpreted_cnode: " << interpreted_cnode->DebugString(debug_recursive_level); interpreted_cnode->set_debug_info(node->debug_info()); return interpreted_cnode; } diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index eaeaed0ed1a..4ee2a36b878 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -2840,7 +2840,10 @@ AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNod const auto &iter = values.find(id_str); if (iter != values.end()) { (void)filter_keys.emplace_back(keys[iter->first]); - (void)filter_values.emplace_back(iter->second); + auto &val_node = iter->second; + // '__py_interpret_local_value_flag__' is used by 'ConvertInterpretedObjForResolve' not to convert PyExecute. + val_node->set_user_data("__py_interpret_local_value_flag__", std::make_shared(true)); + (void)filter_values.emplace_back(val_node); } } @@ -2848,11 +2851,12 @@ AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNod // Update the valued node if it need interpreting. constexpr int recursive_level = 2; MS_EXCEPTION_IF_NULL(block->func_graph()); + AnfNodePtr interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node); MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: `" << script_text << "`,\nvalue_node: " << value_node->DebugString(recursive_level) << ",\nglobal_dict_node: " << global_dict_node->ToString() - << ",\nlocal_dict_node: " << local_dict_node->DebugString(recursive_level); - AnfNodePtr interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node); + << ",\nlocal_dict_node: " << local_dict_node->DebugString(recursive_level) + << ",\ninterpreted_node: " << interpreted_node->DebugString(recursive_level); // Print a hint for user. auto line_info = trace::GetDebugInfo(value_node->debug_info()); diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 842c2861195..7db20f551dc 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -25,6 +25,7 @@ #include "ir/param_info.h" #include "ir/value.h" #include "ir/map_tensor.h" +#include "pipeline/jit/fallback.h" #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/parse/parse.h" #include "include/common/utils/python_adapter.h" @@ -271,6 +272,16 @@ bool HasVariableLenAttr(const py::object &obj) { return py::hasattr(obj, variable_len_attr) && py::cast(py::getattr(obj, variable_len_attr)); } +AnfNodePtr ConvertInterpretedObjForResolve(const AnfNodePtr &origin_node, const ValuePtr &convert_result, + const FuncGraphPtr &func_graph) { + if (convert_result->isa() && !origin_node->has_user_data("__py_interpret_local_value_flag__")) { + constexpr auto recursive_level = 2; + MS_LOG(DEBUG) << "Convert InterpretedObj for resolve, node: " << origin_node->DebugString(recursive_level); + return ConvertInterpretedObjectToPyExecute(func_graph, convert_result, origin_node); + } + return nullptr; +} + AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, const FuncGraphPtr &func_graph) { // When the cell is set recomputed, it should not use old scope from cache. MS_EXCEPTION_IF_NULL(origin_node); @@ -298,14 +309,20 @@ AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object & (IsPrimitiveCNode(origin_node, prim::kPrimPyInterpret) && !origin_node->interpret_internal_type()) || origin_node->interpret(); static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0"); - if (!support_fallback_runtime && !interpret_without_internal && convert_result->isa()) { + MS_EXCEPTION_IF_NULL(convert_result); + if (support_fallback_runtime) { + AnfNodePtr interpreted_output = ConvertInterpretedObjForResolve(origin_node, convert_result, func_graph); + if (interpreted_output != nullptr) { + return interpreted_output; + } + } else if (!interpret_without_internal && convert_result->isa()) { auto type_str = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_GET_TYPE, obj); MS_EXCEPTION(TypeError) << "Do not support to convert " << py::str(type_str) << " object into graph node." << ".\nFor more details, please refer to " << "https://mindspore.cn/docs/zh-CN/master/search.html?q=Do+not+support+to+convert+object" << "+into+graph+node&check_keywords=yes&area=default\n"; } - MS_EXCEPTION_IF_NULL(convert_result); + if (convert_result->isa() && has_recompute_scope) { UpdateDebugInfo(convert_result->cast(), origin_node->scope(), origin_node->debug_info()); } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 2adf504170d..232f9949b83 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -43,6 +43,7 @@ #include "include/common/utils/config_manager.h" #include "include/common/utils/convert_utils.h" #include "include/common/utils/convert_utils_py.h" +#include "include/common/utils/python_utils.h" #include "utils/ms_context.h" #include "utils/shape_utils.h" #include "utils/info.h" @@ -369,28 +370,28 @@ std::pair GetPyExecuteOutput(const AnfNodePtr &output, const B } } // namespace -std::string GetObjDesc(const py::object &source_obj) { +std::string GetObjDesc(const py::object &source) { std::string obj_desc; - if (py::hasattr(source_obj, parse::PYTHON_PARSE_METHOD)) { - auto cell_class_name = source_obj.attr("__class__").attr("__name__"); - auto ms_function_name = source_obj.attr(parse::PYTHON_PARSE_METHOD); + if (py::hasattr(source, parse::PYTHON_PARSE_METHOD)) { + auto cell_class_name = source.attr("__class__").attr("__name__"); + auto ms_function_name = source.attr(parse::PYTHON_PARSE_METHOD); obj_desc = "'" + py::cast(cell_class_name) + "." + py::cast(ms_function_name) + "'"; } else { - if (py::hasattr(source_obj, "__name__")) { - auto ms_function_name = source_obj.attr("__name__"); + if (py::hasattr(source, "__name__")) { + auto ms_function_name = source.attr("__name__"); obj_desc = "'" + py::cast(ms_function_name) + "'"; - } else if (py::isinstance(source_obj)) { - auto cell_class_name = source_obj.attr("__class__").attr("__name__"); + } else if (py::isinstance(source)) { + auto cell_class_name = source.attr("__class__").attr("__name__"); obj_desc = "'" + py::cast(cell_class_name) + ".construct'"; } else { - MS_EXCEPTION(TypeError) << "The source object is invalid: " << py::str(source_obj); + MS_EXCEPTION(TypeError) << "The source object is invalid: " << py::str(source); } } return obj_desc; } -void CheckArgsValid(const py::object &source_obj, const py::tuple &args) { - std::string obj_desc = GetObjDesc(source_obj); +void CheckArgsValid(const py::object &source, const py::tuple &args) { + std::string obj_desc = GetObjDesc(source); for (size_t i = 0; i < args.size(); i++) { if (!CheckArgValid(args[i])) { MS_EXCEPTION(TypeError) @@ -696,8 +697,8 @@ py::dict GraphExecutorPy::GetAllreduceFusion(const std::string &phase) { // Not support multi thread, not support nested call too. // Here using nested_called flg to avoid nested call. -void GraphExecutorPy::DelNetRes(const py::object &source_obj, const py::set &id) { - ClearArgCache(source_obj); +void GraphExecutorPy::DelNetRes(const py::object &source, const py::set &id) { + ClearArgCache(source); // Del all graphs by different phase for (auto item : id) { DelOneNetRes(item); @@ -881,38 +882,38 @@ void GraphExecutorPy::CleanCompileRes(const ResourcePtr &resource) { MS_LOG(INFO) << "Clean compile resource end"; } -bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs, - const py::object &phase_obj, bool use_vm) { +bool GraphExecutorPy::CompileInner(const py::object &source, const py::tuple &args, const py::dict &kwargs, + const py::object &phase, bool use_vm) { // Check if the phase is valid. - if ((!py::isinstance(phase_obj))) { + if ((!py::isinstance(phase))) { MS_LOG(ERROR) << "The `phase` must be string."; return false; } // Check if the function or net is valid. - if (py::isinstance(source_obj)) { + if (py::isinstance(source)) { MS_LOG(ERROR) << "The source object to compile should not be None."; return false; } // Check if the args of function or net is valid. - CheckArgsValid(source_obj, args); + CheckArgsValid(source, args); - auto phase = py::cast(phase_obj); - PhaseManager::GetInstance().set_phase(phase); - phase_ = phase; - auto obj_desc = GetObjDesc(source_obj); - MS_LOG(INFO) << "Start compiling, phase: " << phase; - MS_LOG(DEBUG) << "source: {" << py::str(source_obj) << "}\nargs: " << py::str(const_cast(args)) + source_ = py::cast(py::str(source)); + phase_ = py::cast(phase); + PhaseManager::GetInstance().set_phase(phase_); + auto obj_desc = GetObjDesc(source); + MS_LOG(INFO) << "Start compiling, phase: " << phase_; + MS_LOG(DEBUG) << "source: {" << source_ << "}\nargs: " << py::str(const_cast(args)) << "\nkwargs: " << py::str(const_cast(kwargs)); - EventMessage::PrintCompileStartMsg(phase, obj_desc); + EventMessage::PrintCompileStartMsg(phase_, obj_desc); ExecutorInfoPtr executor_info = std::make_shared(); - ResourcePtr resource = std::make_shared(source_obj); - InitCompileCacheInfo(resource, phase); + ResourcePtr resource = std::make_shared(source); + InitCompileCacheInfo(resource, phase_); bool use_compile_cache = resource->EnableCompileCache() && resource->func_graph(); ConfigManager::GetInstance().ResetQueue(queue_name_); - auto actions = GetPipeline(resource, phase, use_vm); - std::shared_ptr pip = std::make_shared(resource, FilterActions(actions, phase)); + auto actions = GetPipeline(resource, phase_, use_vm); + std::shared_ptr pip = std::make_shared(resource, FilterActions(actions, phase_)); if (pip->NeedCreateBackend()) { // Create backend asynchronously. @@ -932,26 +933,26 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance()); bool is_parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kSemiAutoParallel || parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kAutoParallel; - bool is_auto_parallel = is_parallel_mode && !py::hasattr(source_obj, parallel::kSkipAutoParallelCompile) && - !py::hasattr(source_obj, parallel::kKeepInputUnchanged); + bool is_auto_parallel = is_parallel_mode && !py::hasattr(source, parallel::kSkipAutoParallelCompile) && + !py::hasattr(source, parallel::kKeepInputUnchanged); ConvertArgs(args, kwargs, is_auto_parallel, &args_abs, &arguments); resource->set_arguments(arguments); resource->set_args_abs(args_abs); executor_info->arg_list_size = args.size() + kwargs.size(); executor_info->resource = resource; - info_[phase] = executor_info; + info_[phase_] = executor_info; pip->Run(); // Save the compiled graph to MsPipeLine. - SaveCompiledGraph(phase); + SaveCompiledGraph(phase_); if (is_parallel_mode) { - ParallelPostProcess(phase, use_compile_cache); + ParallelPostProcess(phase_, use_compile_cache); } #ifdef ENABLE_DUMP_IR mindspore::RDR::Snapshot(); #endif CleanCompileRes(resource); - EventMessage::PrintCompileEndMsg(phase, obj_desc); + EventMessage::PrintCompileEndMsg(phase_, obj_desc); PhaseManager::GetInstance().ClearPhase(); MS_LOG(INFO) << "Finish compiling."; return true; @@ -1027,7 +1028,7 @@ std::vector GraphExecutorPy::FilterActions(const std::vector(phase) << "\n\tinner source: " << source; } - } - ReleaseResource(phase); - - // re-throw this exception to Python interpreter to handle it - throw(py::error_already_set(ex)); - } catch (const py::type_error &ex) { - ReleaseResource(phase); - throw py::type_error(ex); - } catch (const py::value_error &ex) { - ReleaseResource(phase); - throw py::value_error(ex); - } catch (const py::index_error &ex) { - ReleaseResource(phase); - throw py::index_error(ex); - } catch (const py::key_error &ex) { - ReleaseResource(phase); - throw py::key_error(ex); - } catch (const py::attribute_error &ex) { - ReleaseResource(phase); - throw py::attribute_error(ex); - } catch (const py::name_error &ex) { - ReleaseResource(phase); - throw py::name_error(ex); - } catch (const py::assertion_error &ex) { - ReleaseResource(phase); - throw py::assertion_error(ex); - } catch (const py::base_exception &ex) { - ReleaseResource(phase); - throw py::base_exception(ex); - } catch (const py::keyboard_interrupt &ex) { - ReleaseResource(phase); - throw py::keyboard_interrupt(ex); - } catch (const py::stop_iteration &ex) { - ReleaseResource(phase); - throw py::stop_iteration(ex); - } catch (const py::overflow_error &ex) { - ReleaseResource(phase); - throw py::overflow_error(ex); - } catch (const py::zero_division_error &ex) { - ReleaseResource(phase); - throw py::zero_division_error(ex); - } catch (const py::environment_error &ex) { - ReleaseResource(phase); - throw py::environment_error(ex); - } catch (const py::io_error &ex) { - ReleaseResource(phase); - throw py::io_error(ex); - } catch (const py::os_error &ex) { - ReleaseResource(phase); - throw py::os_error(ex); - } catch (const py::memory_error &ex) { - ReleaseResource(phase); - throw py::memory_error(ex); - } catch (const py::unbound_local_error &ex) { - ReleaseResource(phase); - throw py::unbound_local_error(ex); - } catch (const py::not_implemented_error &ex) { - ReleaseResource(phase); - throw py::not_implemented_error(ex); - } catch (const py::indentation_error &ex) { - ReleaseResource(phase); - throw py::indentation_error(ex); - } catch (const py::runtime_warning &ex) { - ReleaseResource(phase); - throw py::runtime_warning(ex); - } catch (const std::exception &ex) { - ReleaseResource(phase); - // re-throw this exception to Python interpreter to handle it - throw(std::runtime_error(ex.what())); - } catch (...) { - ReleaseResource(phase); -#ifndef _MSC_VER - std::string exName(abi::__cxa_current_exception_type()->name()); - MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; -#else - MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: "; -#endif - } - return ret_value; + ProcessStatus::GetInstance().RecordStart("CompileInner"); + res = CompileInner(source, args, kwargs, phase, use_vm); + ProcessStatus::GetInstance().RecordEnd(); + ProcessStatus::GetInstance().Print(); + }, + [this, &phase]() { + if (!StaticAnalysisException::Instance().HasException()) { + // print function call stack info before release + std::string compile_exception_info = GetCompileExceptionInfo(); + if (!compile_exception_info.empty()) { + MS_LOG(ERROR) << compile_exception_info; + } + } + ReleaseResourceOnException(phase); + }, + [this, &phase]() { ReleaseResourceOnException(phase); }, [this, &phase]() { ReleaseResourceOnException(phase); }); + return res; } void CacheValidateFuncGraph(const ResourcePtr &resource) { @@ -1395,8 +1329,21 @@ std::pair GraphExecutorPy::GetPyExecuteOutputFromAddress(const return {py::none(), false}; } -py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_obj) { - // init for dynamic-obfuscated model infer +py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase) { + py::object res; + HandleExceptionRethrow( + [this, &res, &args, &phase]() { + executor_running_ = true; + res = RunInner(args, phase); + executor_running_ = false; + }, + [this]() { executor_running_ = false; }, [this]() { executor_running_ = false; }, + [this]() { executor_running_ = false; }); + return res; +} + +py::object GraphExecutorPy::RunInner(const py::tuple &args, const py::object &phase_obj) { + // Init for dynamic-obfuscated model infer (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count(); // Mindspore debugger notify main thread to exit after one step, and will not run next step #ifdef ENABLE_DEBUGGER diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.h b/mindspore/ccsrc/pipeline/jit/pipeline.h index f445d20ffa4..2e46a6f74aa 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.h +++ b/mindspore/ccsrc/pipeline/jit/pipeline.h @@ -74,20 +74,16 @@ class GraphExecutorPy : public std::enable_shared_from_this { ~GraphExecutorPy(); + bool Compile(const py::object &source, const py::tuple &args, const py::dict &kwargs, const py::object &phase, + bool use_vm); + py::object Run(const py::tuple &args, const py::object &phase); + const std::string &phase() const { return phase_; } const std::map &jit_config() const { return jit_config_; } void SaveCompiledGraph(const std::string &phase); void ConvertArgs(const py::tuple &args, const py::dict &kwargs, bool is_auto_parallel, abstract::AbstractBasePtrList *args_abs, std::vector *arguments); - bool CompileInner(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs, - const py::object &phase_obj, bool use_vm); - bool Compile(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs, const py::object &phase, - bool use_vm); - void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list); - - // for pynative mode when use_vm is on - py::object Run(const py::tuple &args, const py::object &phase_obj); std::pair GetPyExecuteOutputFromAddress(const py::object &res, const BaseRef &value); ResourcePtr GetResource(const std::string &phase); FuncGraphPtr GetFuncGraph(const std::string &phase); @@ -121,8 +117,8 @@ class GraphExecutorPy : public std::enable_shared_from_this { size_t GetNumOpsInfo(const std::string &phase); void SetNumOpsInfo(size_t num_ops); py::dict GetAllreduceFusion(const std::string &phase); - void DelNetRes(const py::object &source_obj, const py::set &id); - void ReleaseResource(const py::object &phase); + void DelNetRes(const py::object &source, const py::set &id); + void ReleaseResourceOnException(const py::object &phase); void CleanCompileRes(const ResourcePtr &resource); static void ClearRes(); void set_queue_name(const std::string &queue_name) { queue_name_ = queue_name; } @@ -153,11 +149,16 @@ class GraphExecutorPy : public std::enable_shared_from_this { // If enable compile cache, get the compile cache resource. void InitCompileCacheInfo(const ResourcePtr &resource, const std::string &phase); + bool CompileInner(const py::object &source, const py::tuple &args, const py::dict &kwargs, const py::object &phase, + bool use_vm); + py::object RunInner(const py::tuple &args, const py::object &phase); + std::map info_; static std::shared_ptr executor_; static std::mutex instance_lock_; std::map stra_dict_; - std::string phase_ = ""; + std::string phase_{""}; + std::string source_{""}; std::map jit_config_; std::map phase_to_num_op_info_; std::string queue_name_; @@ -166,14 +167,15 @@ class GraphExecutorPy : public std::enable_shared_from_this { bool compile_cache_consistent_{true}; py::dict weights_; std::map> cur_convert_input_; + bool executor_running_{false}; }; using GraphExecutorPyPtr = std::shared_ptr; std::string GetJitLevel(); -std::string GetObjDesc(const py::object &source_obj); +std::string GetObjDesc(const py::object &source); bool IsPhaseLoadFromMindIR(const std::string &phase); -void CheckArgsValid(const py::object &source_obj, const py::tuple &args); +void CheckArgsValid(const py::object &source, const py::tuple &args); py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple &inputs); bool InitDistribute(const std::map &options); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 8158e8e99df..72c9dda6ba6 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -1691,6 +1691,13 @@ EvalResultPtr GetEvaluatedValueForAdapterTensorAttrOrMethod(const AnalysisEngine return StaticGetterInferred(converted_value, data_conf, out_conf, require_type); } +bool IsPyExecuteCNodeData(const AbstractBasePtr &data_abstract) { + if (data_abstract->has_user_data("__py_execute_cnode_flag__")) { + return true; + } + return false; +} + void CheckObjAttrValid(const TypePtr &data_type, const std::string &item_name, const AbstractBasePtr &data_args) { // Check if the obj's attr is invalid or decoratored by @jit_forbidden_register std::string data_type_str = TypeIdLabel(data_type->type_id()); @@ -1747,10 +1754,15 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name; } - CheckObjAttrValid(data_type, item_name, data_args); + constexpr auto recursive_level = 3; MS_LOG(DEBUG) << "Evaluate " << data_type->ToString() << " attribute: " << item_name - << ".\nnode: " << out_conf->node()->DebugString() << "\n" + << ".\nnode: " << out_conf->node()->DebugString(recursive_level) << "\n" << trace::GetDebugInfo(out_conf->node()->debug_info()); + auto cnode = dyn_cast(out_conf->node()); + MS_EXCEPTION_IF_NULL(cnode); + if (!IsPyExecuteCNodeData(data_args)) { // Not check if the data is PyExecute CNode. + CheckObjAttrValid(data_type, item_name, data_args); + } auto res = InterpretGetAttrNode(args_abs_list, out_conf); if (res == nullptr) { MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name; @@ -2067,6 +2079,8 @@ EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abst shape = std::make_shared(shp); } AbstractBasePtr res = std::make_shared(type, shape); + // User data '__py_execute_cnode_flag__' is used by 'IsPyExecuteCNodeData' to check forward PyExecute CNode. + res->set_user_data("__py_execute_cnode_flag__", std::make_shared(true)); auto infer_result = std::make_shared(res, std::make_shared()); evaluator_cache_mgr_->SetValue(args_abs_list, infer_result); return infer_result; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 514078f01d9..ae443aa73d3 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -327,10 +327,12 @@ EvalResultPtr AnalysisEngine::InterpretedNodeCall(const CNodePtr &cnode, const A // Check if the operator input is PyExecute CNode. auto &func_node = inputs[0]; MS_EXCEPTION_IF_NULL(func_node); - constexpr auto debug_recursive_level = 2; - MS_LOG(DEBUG) << "Current CNode: " << cnode->DebugString(debug_recursive_level); + constexpr auto recursive_level = 2; + MS_LOG(DEBUG) << "Current CNode: " << cnode->DebugString(recursive_level) + << ", func_node: " << func_node->DebugString(recursive_level); auto prim = GetCNodePrimitiveWithoutDoSignature(func_node); - if (!IsPrimitiveEquals(prim, prim::kPrimGetAttr)) { // Optimize the performance. + if (!IsPrimitiveEquals(prim, prim::kPrimGetAttr) && + !IsPrimitiveEquals(prim, prim::kPrimPyExecute)) { // Optimize the performance. return nullptr; } AnfNodeConfigPtr func_conf = MakeConfig(func_node, conf->context(), conf->func_graph()); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index d564c7c1837..fffe890658c 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -29,6 +29,7 @@ #include "ir/cell.h" #include "abstract/utils.h" #include "include/common/utils/stub_tensor.h" +#include "include/common/utils/python_utils.h" namespace mindspore::pynative { std::shared_ptr PyNativeExecutor::executor_ = nullptr; @@ -43,46 +44,26 @@ T PyNativeExecutorTry(const std::function &method, const Arg const auto &inst = PyNativeExecutor::GetInstance(); MS_EXCEPTION_IF_NULL(inst); MS_EXCEPTION_IF_NULL(method); - try { - return method(args...); - } catch (const py::error_already_set &ex) { - // print function call stack info before release + auto already_set_error_handler = [&inst]() { + // Print function call stack info before release. std::ostringstream oss; trace::TraceGraphEval(); trace::GetEvalStackInfo(oss); - // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see - // these info from screen, no need to open log file to find these info + // Call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see + // these info from screen, no need to open log file to find these info. py::print(oss.str()); MS_LOG(ERROR) << oss.str(); inst->ClearRes(); - // re-throw this exception to Python interpreter to handle it - throw(py::error_already_set(ex)); - } catch (const py::index_error &ex) { - inst->ClearRes(); - throw py::index_error(ex); - } catch (const py::value_error &ex) { - inst->ClearRes(); - throw py::value_error(ex); - } catch (const py::type_error &ex) { - inst->ClearRes(); - throw py::type_error(ex); - } catch (const py::name_error &ex) { - inst->ClearRes(); - throw py::name_error(ex); - } catch (const std::exception &ex) { - inst->ClearRes(); - // re-throw this exception to Python interpreter to handle it - throw(std::runtime_error(ex.what())); - } catch (...) { - inst->ClearRes(); -#ifndef _MSC_VER - auto exception_type = abi::__cxa_current_exception_type(); - MS_EXCEPTION_IF_NULL(exception_type); - std::string ex_name(exception_type->name()); - MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << ex_name; -#else - MS_LOG(EXCEPTION) << "Error occurred when compile graph."; -#endif + }; + + if constexpr (std::is_same_v) { + HandleExceptionRethrow([&method, &args...]() { method(args...); }, already_set_error_handler, + [&inst]() { inst->ClearRes(); }, [&inst]() { inst->ClearRes(); }); + } else { + T res; + HandleExceptionRethrow([&res, &method, &args...]() { res = method(args...); }, already_set_error_handler, + [&inst]() { inst->ClearRes(); }, [&inst]() { inst->ClearRes(); }); + return res; } } diff --git a/mindspore/ccsrc/pybind_api/ir/py_execute_py.h b/mindspore/ccsrc/pybind_api/ir/py_execute_py.h index 8811315ea20..571423637f9 100644 --- a/mindspore/ccsrc/pybind_api/ir/py_execute_py.h +++ b/mindspore/ccsrc/pybind_api/ir/py_execute_py.h @@ -29,6 +29,7 @@ #include "include/common/fallback.h" #include "mindspore/core/ops/py_execute.h" #include "mindspore/ccsrc/include/common/utils/convert_utils_py.h" +#include "mindspore/ccsrc/include/common/utils/python_utils.h" #include "mindspore/ccsrc/include/common/utils/python_adapter.h" #include "mindspore/ccsrc/include/common/utils/python_fallback_running.h" #include "mindspore/ccsrc/pipeline/jit/parse/data_converter.h" @@ -161,96 +162,20 @@ class PyExecuteInitializer { const AbstractBasePtrList &args_spec_list) { // We can't catch the pybind11 exception by py::builtin_exception or its base class, // so we have to list all pybind11 exceptions and catch one by one here. - try { - const auto &abs = opt::CppInferShapeAndType(primitive, args_spec_list); - MS_LOG(DEBUG) << "The abstract of " << cnode->fullname_with_scope() << " changes from " << cnode->abstract() - << " to " << abs; - return abs; - } catch (const py::type_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::type_error(ss.str()); - } catch (const py::value_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::value_error(ss.str()); - } catch (const py::index_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::index_error(ss.str()); - } catch (const py::key_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::key_error(ss.str()); - } catch (const py::attribute_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::attribute_error(ss.str()); - } catch (const py::name_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::name_error(ss.str()); - } catch (const py::assertion_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::assertion_error(ss.str()); - } catch (const py::base_exception &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::base_exception(ss.str()); - } catch (const py::keyboard_interrupt &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::keyboard_interrupt(ss.str()); - } catch (const py::stop_iteration &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::stop_iteration(ss.str()); - } catch (const py::overflow_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::overflow_error(ss.str()); - } catch (const py::zero_division_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::zero_division_error(ss.str()); - } catch (const py::environment_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::environment_error(ss.str()); - } catch (const py::io_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::io_error(ss.str()); - } catch (const py::os_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::os_error(ss.str()); - } catch (const py::memory_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::memory_error(ss.str()); - } catch (const py::unbound_local_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::unbound_local_error(ss.str()); - } catch (const py::not_implemented_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::not_implemented_error(ss.str()); - } catch (const py::indentation_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::indentation_error(ss.str()); - } catch (const py::runtime_warning &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw py::runtime_warning(ss.str()); - } catch (const std::runtime_error &e) { - std::stringstream ss; - ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info()); - throw std::runtime_error(ss.str()); - } + AbstractBasePtr res; + std::function already_set_error_handler; + std::function other_error_handler; + std::function default_error_handler; + HandleExceptionRethrow( + [&res, &cnode, &primitive, &args_spec_list]() { + res = opt::CppInferShapeAndType(primitive, args_spec_list); + MS_LOG(DEBUG) << "The abstract of " << cnode->fullname_with_scope() << " changes from " << cnode->abstract() + << " to " << res; + return res; + }, + already_set_error_handler, other_error_handler, default_error_handler, + cnode->debug_info()); // Use debug_info to re-throw. + return res; } }; diff --git a/mindspore/ccsrc/utils/python_utils.cc b/mindspore/ccsrc/utils/python_utils.cc new file mode 100644 index 00000000000..572e2a18b07 --- /dev/null +++ b/mindspore/ccsrc/utils/python_utils.cc @@ -0,0 +1,307 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/common/utils/python_utils.h" + +#include "include/common/utils/python_adapter.h" +#include "pybind_api/pybind_patch.h" + +namespace mindspore { +void HandleExceptionRethrow(const std::function &main_func, + const std::function &already_set_error_handler, + const std::function &other_error_handler, + const std::function &default_error_handler, const DebugInfoPtr &debug_info) { + try { + if (!main_func) { + MS_LOG(ERROR) << "The 'main_func' should not be empty."; + return; + } + main_func(); + } catch (const py::error_already_set &ex) { + if (already_set_error_handler) { + already_set_error_handler(); + } + // Re-throw this exception to Python interpreter to handle it + throw(py::error_already_set(ex)); + } catch (const py::type_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::type_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::type_error(ss.str()); + } + } catch (const py::value_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::value_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::value_error(ss.str()); + } + } catch (const py::index_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::index_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::index_error(ss.str()); + } + } catch (const py::key_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::key_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::key_error(ss.str()); + } + } catch (const py::attribute_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::attribute_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::attribute_error(ss.str()); + } + } catch (const py::name_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::name_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::name_error(ss.str()); + } + } catch (const py::assertion_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::assertion_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::assertion_error(ss.str()); + } + } catch (const py::base_exception &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::base_exception(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::base_exception(ss.str()); + } + } catch (const py::keyboard_interrupt &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::keyboard_interrupt(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::keyboard_interrupt(ss.str()); + } + } catch (const py::stop_iteration &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::stop_iteration(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::stop_iteration(ss.str()); + } + } catch (const py::overflow_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::overflow_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::overflow_error(ss.str()); + } + } catch (const py::zero_division_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::zero_division_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::zero_division_error(ss.str()); + } + } catch (const py::environment_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::environment_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::environment_error(ss.str()); + } + } catch (const py::io_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::io_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::io_error(ss.str()); + } + } catch (const py::os_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::os_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::os_error(ss.str()); + } + } catch (const py::memory_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::memory_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::memory_error(ss.str()); + } + } catch (const py::unbound_local_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::unbound_local_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::unbound_local_error(ss.str()); + } + } catch (const py::not_implemented_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::not_implemented_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::not_implemented_error(ss.str()); + } + } catch (const py::indentation_error &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::indentation_error(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::indentation_error(ss.str()); + } + } catch (const py::runtime_warning &ex) { + if (other_error_handler) { + other_error_handler(); + } + + if (debug_info == nullptr) { + throw py::runtime_warning(ex); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw py::runtime_warning(ss.str()); + } + } catch (const std::exception &ex) { + if (other_error_handler) { + other_error_handler(); + } + + // Re-throw this exception to Python interpreter to handle it. + if (debug_info == nullptr) { + throw std::runtime_error(ex.what()); + } else { + std::stringstream ss; + ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info); + throw std::runtime_error(ss.str()); + } + } catch (...) { + if (default_error_handler) { + default_error_handler(); + } + +#ifndef _MSC_VER + auto exception_type = abi::__cxa_current_exception_type(); + MS_EXCEPTION_IF_NULL(exception_type); + std::string ex_name(exception_type->name()); + MS_LOG(EXCEPTION) << "Error occurred. Exception name: " << ex_name; +#else + MS_LOG(EXCEPTION) << "Error occurred."; +#endif + } +} +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/utils.cc b/mindspore/ccsrc/utils/utils.cc index 0aa2f5ac5a9..995f59c445c 100644 --- a/mindspore/ccsrc/utils/utils.cc +++ b/mindspore/ccsrc/utils/utils.cc @@ -15,8 +15,10 @@ */ #include "include/common/utils/utils.h" + #include #include + namespace mindspore { bool IsOneOfPosteriorOperator(const std::string &name) { const std::set kPosteriorOperatorSet = {kPullOpName}; diff --git a/mindspore/python/mindspore/_extends/parse/parser.py b/mindspore/python/mindspore/_extends/parse/parser.py index e96edc3c850..168d507d318 100644 --- a/mindspore/python/mindspore/_extends/parse/parser.py +++ b/mindspore/python/mindspore/_extends/parse/parser.py @@ -815,7 +815,7 @@ def eval_script(exp_str, params): except Exception as e: error_info = f"When eval '{exp_str}' by using JIT Fallback feature, an error occurred: " + str(e) + \ ". You can try to turn off JIT Fallback feature by 'export MS_DEV_ENABLE_FALLBACK=0'." - logger.error(error_info) + logger.debug(error_info) raise e # Convert set to tuple. diff --git a/tests/st/fallback/test_graph_fallback_runtime.py b/tests/st/fallback/test_graph_fallback_runtime.py index eced36d96a5..9d0493d71d0 100644 --- a/tests/st/fallback/test_graph_fallback_runtime.py +++ b/tests/st/fallback/test_graph_fallback_runtime.py @@ -873,3 +873,67 @@ def test_parser_fallback_nested_class_outer_grad(): net = NestedNet() output = ops.grad(net)(mutable(x), y) assert output == 0 + + +class UserDefinedNet: + def __init__(self): + self.value = 10 + + def __call__(self, x): + return self.value * x + + +class UserDefinedMsFunctionCallNet: + def __init__(self): + self.value = 10 + + @ms.jit + def __call__(self, x): + return self.value * x + + +class UNet(ms.nn.Cell): + def __init__(self, net): + super().__init__() + self.net = net + + def construct(self, x): + out = x * x + out = self.net(x) + out = out + out + return out + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_resolve_cust_class(): + """ + Feature: Syntax resolve. + Description: Graph syntax resolve support custom class input. + Expectation: No error. + """ + net = UNet(UserDefinedNet()) + x = np.array([10], np.float32) + output = net(ms.Tensor(x)) + print(output) # The output should == 200, but failed, check later. + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_resolve_cust_ms_function_call_class(): + """ + Feature: Syntax resolve. + Description: Graph syntax resolve support custom class input. + Expectation: No error. + """ + net = UNet(UserDefinedMsFunctionCallNet()) + x = np.array([10, 10], np.float32) + with pytest.raises(RuntimeError) as err: + net(ms.Tensor(x)) + assert "Nested execution during JIT execution is not supported." in str(err.value)