[JIT Fallback] Support the JIT Fallback Resolve routine; refactor the exception re-throw routine; throw an exception on nested JIT execution.

张清华 2023-02-22 15:26:43 +08:00
parent fcd672091d
commit dec3b54904
16 changed files with 579 additions and 282 deletions

View File

@@ -18,10 +18,9 @@ mindspore/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc:mindspo
mindspore/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc:mindspore::device::gpu::GPUKernelRuntime::LaunchKernelDynamic
mindspore/mindspore/ccsrc/pipeline/jit/init.cc:PYBIND11_MODULE
mindspore/mindspore/ccsrc/pipeline/jit/parse/resolve.cc:mindspore::parse::ResolveObjectToNode
mindspore/mindspore/ccsrc/pipeline/jit/pipeline.cc:mindspore::pipeline::GraphExecutorPy::Compile
mindspore/mindspore/ccsrc/utils/python_utils.cc:mindspore::HandleExceptionRethrow
mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc:mindspore::abstract::ConvertAbstractToPython
mindspore/mindspore/ccsrc/pybind_api/ir/log_adapter_py.h:mindspore::PyExceptionInitializer::HandleExceptionPy
mindspore/mindspore/ccsrc/pybind_api/ir/py_execute_py.h:mindspore::PyExecuteInitializer::CppInferShapeAndTypePy
mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/unary_op_gpu_kernel.h:mindspore::kernel::UnaryOpGpuKernel::Launch
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_rnn_grad_fission_v2.cc:mindspore::opt::AddLSTMInputGradNode
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/drop_out_gen_mask_kernels.cc:aicpu::ARMDropOutGenMaskKernel

View File

@@ -58,7 +58,7 @@ class SwitchSimplify : public OptimizerCaller {
<< " not support this condition value: " << value_ptr->ToString();
}
MS_LOG(DEBUG) << "condition value: " << value_ptr->ToString() << " bool:" << cond_value;
MS_LOG(DEBUG) << "condition value: " << value_ptr->ToString() << ", cond: " << cond_value;
if (cond_value) {
return true_br.GetNode(node);
}

View File

@@ -0,0 +1,32 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_UTILS_H_
#define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_UTILS_H_
#include <functional>
#include "include/common/visible.h"
#include "utils/trace_base.h"
namespace mindspore {
COMMON_EXPORT void HandleExceptionRethrow(const std::function<void(void)> &main_func,
const std::function<void(void)> &already_set_error_handler,
const std::function<void(void)> &other_error_handler,
const std::function<void(void)> &default_error_handler,
const DebugInfoPtr &debug_info = nullptr);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_UTILS_H_
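
For reference, a minimal usage sketch of the new helper (editor's illustration, not part of the commit): the handler lambdas below are placeholders that mirror how pipeline.cc and the PyNative executor call it later in this diff.
// Editor's sketch: typical call shape for HandleExceptionRethrow (placeholder handlers).
#include "include/common/utils/python_utils.h"
void CompileGuarded(const std::function<void(void)> &do_compile,
                    const std::function<void(void)> &release_resources) {
  mindspore::HandleExceptionRethrow(
      [&do_compile]() { do_compile(); },                 // main_func: the work that may throw
      [&release_resources]() { release_resources(); },   // already_set_error_handler: py::error_already_set
      [&release_resources]() { release_resources(); },   // other_error_handler: other pybind11/std exceptions
      [&release_resources]() { release_resources(); });  // default_error_handler: unknown exception types
}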

View File

@@ -58,7 +58,8 @@ AnfNodePtr ConvertInterpretedObjectToPyExecute(const FuncGraphPtr &fg, const Val
NewValueNode(std::make_shared<ValueTuple>(keys)),
NewValueNode(std::make_shared<ValueTuple>(values))});
constexpr auto debug_recursive_level = 2;
MS_LOG(DEBUG) << "interpreted_cnode: " << interpreted_cnode->DebugString(debug_recursive_level);
MS_LOG(DEBUG) << "original node: " << node->DebugString(debug_recursive_level)
<< ", interpreted_cnode: " << interpreted_cnode->DebugString(debug_recursive_level);
interpreted_cnode->set_debug_info(node->debug_info());
return interpreted_cnode;
}

View File

@@ -2840,7 +2840,10 @@ AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNod
const auto &iter = values.find(id_str);
if (iter != values.end()) {
(void)filter_keys.emplace_back(keys[iter->first]);
(void)filter_values.emplace_back(iter->second);
auto &val_node = iter->second;
// '__py_interpret_local_value_flag__' tells 'ConvertInterpretedObjForResolve' not to convert this value to PyExecute.
val_node->set_user_data("__py_interpret_local_value_flag__", std::make_shared<bool>(true));
(void)filter_values.emplace_back(val_node);
}
}
@@ -2848,11 +2851,12 @@ AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNod
// Update the value node if it needs interpreting.
constexpr int recursive_level = 2;
MS_EXCEPTION_IF_NULL(block->func_graph());
AnfNodePtr interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node);
MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: `" << script_text
<< "`,\nvalue_node: " << value_node->DebugString(recursive_level)
<< ",\nglobal_dict_node: " << global_dict_node->ToString()
<< ",\nlocal_dict_node: " << local_dict_node->DebugString(recursive_level);
AnfNodePtr interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node);
<< ",\nlocal_dict_node: " << local_dict_node->DebugString(recursive_level)
<< ",\ninterpreted_node: " << interpreted_node->DebugString(recursive_level);
// Print a hint for user.
auto line_info = trace::GetDebugInfo(value_node->debug_info());

View File

@@ -25,6 +25,7 @@
#include "ir/param_info.h"
#include "ir/value.h"
#include "ir/map_tensor.h"
#include "pipeline/jit/fallback.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/parse.h"
#include "include/common/utils/python_adapter.h"
@@ -271,6 +272,16 @@ bool HasVariableLenAttr(const py::object &obj) {
return py::hasattr(obj, variable_len_attr) && py::cast<bool>(py::getattr(obj, variable_len_attr));
}
AnfNodePtr ConvertInterpretedObjForResolve(const AnfNodePtr &origin_node, const ValuePtr &convert_result,
const FuncGraphPtr &func_graph) {
if (convert_result->isa<InterpretedObject>() && !origin_node->has_user_data("__py_interpret_local_value_flag__")) {
constexpr auto recursive_level = 2;
MS_LOG(DEBUG) << "Convert InterpretedObj for resolve, node: " << origin_node->DebugString(recursive_level);
return ConvertInterpretedObjectToPyExecute(func_graph, convert_result, origin_node);
}
return nullptr;
}
AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, const FuncGraphPtr &func_graph) {
// When the cell is set recomputed, it should not use old scope from cache.
MS_EXCEPTION_IF_NULL(origin_node);
@@ -298,14 +309,20 @@ AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &
(IsPrimitiveCNode(origin_node, prim::kPrimPyInterpret) && !origin_node->interpret_internal_type()) ||
origin_node->interpret();
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (!support_fallback_runtime && !interpret_without_internal && convert_result->isa<InterpretedObject>()) {
MS_EXCEPTION_IF_NULL(convert_result);
if (support_fallback_runtime) {
AnfNodePtr interpreted_output = ConvertInterpretedObjForResolve(origin_node, convert_result, func_graph);
if (interpreted_output != nullptr) {
return interpreted_output;
}
} else if (!interpret_without_internal && convert_result->isa<InterpretedObject>()) {
auto type_str = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_GET_TYPE, obj);
MS_EXCEPTION(TypeError) << "Do not support to convert " << py::str(type_str) << " object into graph node."
<< ".\nFor more details, please refer to "
<< "https://mindspore.cn/docs/zh-CN/master/search.html?q=Do+not+support+to+convert+object"
<< "+into+graph+node&check_keywords=yes&area=default\n";
}
MS_EXCEPTION_IF_NULL(convert_result);
if (convert_result->isa<FuncGraph>() && has_recompute_scope) {
UpdateDebugInfo(convert_result->cast<FuncGraphPtr>(), origin_node->scope(), origin_node->debug_info());
}

View File

@@ -43,6 +43,7 @@
#include "include/common/utils/config_manager.h"
#include "include/common/utils/convert_utils.h"
#include "include/common/utils/convert_utils_py.h"
#include "include/common/utils/python_utils.h"
#include "utils/ms_context.h"
#include "utils/shape_utils.h"
#include "utils/info.h"
@@ -369,28 +370,28 @@ std::pair<py::object, bool> GetPyExecuteOutput(const AnfNodePtr &output, const B
}
} // namespace
std::string GetObjDesc(const py::object &source_obj) {
std::string GetObjDesc(const py::object &source) {
std::string obj_desc;
if (py::hasattr(source_obj, parse::PYTHON_PARSE_METHOD)) {
auto cell_class_name = source_obj.attr("__class__").attr("__name__");
auto ms_function_name = source_obj.attr(parse::PYTHON_PARSE_METHOD);
if (py::hasattr(source, parse::PYTHON_PARSE_METHOD)) {
auto cell_class_name = source.attr("__class__").attr("__name__");
auto ms_function_name = source.attr(parse::PYTHON_PARSE_METHOD);
obj_desc = "'" + py::cast<std::string>(cell_class_name) + "." + py::cast<std::string>(ms_function_name) + "'";
} else {
if (py::hasattr(source_obj, "__name__")) {
auto ms_function_name = source_obj.attr("__name__");
if (py::hasattr(source, "__name__")) {
auto ms_function_name = source.attr("__name__");
obj_desc = "'" + py::cast<std::string>(ms_function_name) + "'";
} else if (py::isinstance<Cell>(source_obj)) {
auto cell_class_name = source_obj.attr("__class__").attr("__name__");
} else if (py::isinstance<Cell>(source)) {
auto cell_class_name = source.attr("__class__").attr("__name__");
obj_desc = "'" + py::cast<std::string>(cell_class_name) + ".construct'";
} else {
MS_EXCEPTION(TypeError) << "The source object is invalid: " << py::str(source_obj);
MS_EXCEPTION(TypeError) << "The source object is invalid: " << py::str(source);
}
}
return obj_desc;
}
void CheckArgsValid(const py::object &source_obj, const py::tuple &args) {
std::string obj_desc = GetObjDesc(source_obj);
void CheckArgsValid(const py::object &source, const py::tuple &args) {
std::string obj_desc = GetObjDesc(source);
for (size_t i = 0; i < args.size(); i++) {
if (!CheckArgValid(args[i])) {
MS_EXCEPTION(TypeError)
@@ -696,8 +697,8 @@ py::dict GraphExecutorPy::GetAllreduceFusion(const std::string &phase) {
// Multi-thread is not supported, and neither are nested calls.
// A nested_called flag is used here to avoid nested calls.
void GraphExecutorPy::DelNetRes(const py::object &source_obj, const py::set &id) {
ClearArgCache(source_obj);
void GraphExecutorPy::DelNetRes(const py::object &source, const py::set &id) {
ClearArgCache(source);
// Del all graphs by different phase
for (auto item : id) {
DelOneNetRes(item);
@@ -881,38 +882,38 @@ void GraphExecutorPy::CleanCompileRes(const ResourcePtr &resource) {
MS_LOG(INFO) << "Clean compile resource end";
}
bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs,
const py::object &phase_obj, bool use_vm) {
bool GraphExecutorPy::CompileInner(const py::object &source, const py::tuple &args, const py::dict &kwargs,
const py::object &phase, bool use_vm) {
// Check if the phase is valid.
if ((!py::isinstance<py::str>(phase_obj))) {
if ((!py::isinstance<py::str>(phase))) {
MS_LOG(ERROR) << "The `phase` must be string.";
return false;
}
// Check if the function or net is valid.
if (py::isinstance<py::none>(source_obj)) {
if (py::isinstance<py::none>(source)) {
MS_LOG(ERROR) << "The source object to compile should not be None.";
return false;
}
// Check if the args of function or net is valid.
CheckArgsValid(source_obj, args);
CheckArgsValid(source, args);
auto phase = py::cast<std::string>(phase_obj);
PhaseManager::GetInstance().set_phase(phase);
phase_ = phase;
auto obj_desc = GetObjDesc(source_obj);
MS_LOG(INFO) << "Start compiling, phase: " << phase;
MS_LOG(DEBUG) << "source: {" << py::str(source_obj) << "}\nargs: " << py::str(const_cast<py::tuple &>(args))
source_ = py::cast<std::string>(py::str(source));
phase_ = py::cast<std::string>(phase);
PhaseManager::GetInstance().set_phase(phase_);
auto obj_desc = GetObjDesc(source);
MS_LOG(INFO) << "Start compiling, phase: " << phase_;
MS_LOG(DEBUG) << "source: {" << source_ << "}\nargs: " << py::str(const_cast<py::tuple &>(args))
<< "\nkwargs: " << py::str(const_cast<py::dict &>(kwargs));
EventMessage::PrintCompileStartMsg(phase, obj_desc);
EventMessage::PrintCompileStartMsg(phase_, obj_desc);
ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
ResourcePtr resource = std::make_shared<Resource>(source_obj);
InitCompileCacheInfo(resource, phase);
ResourcePtr resource = std::make_shared<Resource>(source);
InitCompileCacheInfo(resource, phase_);
bool use_compile_cache = resource->EnableCompileCache() && resource->func_graph();
ConfigManager::GetInstance().ResetQueue(queue_name_);
auto actions = GetPipeline(resource, phase, use_vm);
std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, FilterActions(actions, phase));
auto actions = GetPipeline(resource, phase_, use_vm);
std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, FilterActions(actions, phase_));
if (pip->NeedCreateBackend()) {
// Create backend asynchronously.
@@ -932,26 +933,26 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
bool is_parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kSemiAutoParallel ||
parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kAutoParallel;
bool is_auto_parallel = is_parallel_mode && !py::hasattr(source_obj, parallel::kSkipAutoParallelCompile) &&
!py::hasattr(source_obj, parallel::kKeepInputUnchanged);
bool is_auto_parallel = is_parallel_mode && !py::hasattr(source, parallel::kSkipAutoParallelCompile) &&
!py::hasattr(source, parallel::kKeepInputUnchanged);
ConvertArgs(args, kwargs, is_auto_parallel, &args_abs, &arguments);
resource->set_arguments(arguments);
resource->set_args_abs(args_abs);
executor_info->arg_list_size = args.size() + kwargs.size();
executor_info->resource = resource;
info_[phase] = executor_info;
info_[phase_] = executor_info;
pip->Run();
// Save the compiled graph to MsPipeLine.
SaveCompiledGraph(phase);
SaveCompiledGraph(phase_);
if (is_parallel_mode) {
ParallelPostProcess(phase, use_compile_cache);
ParallelPostProcess(phase_, use_compile_cache);
}
#ifdef ENABLE_DUMP_IR
mindspore::RDR::Snapshot();
#endif
CleanCompileRes(resource);
EventMessage::PrintCompileEndMsg(phase, obj_desc);
EventMessage::PrintCompileEndMsg(phase_, obj_desc);
PhaseManager::GetInstance().ClearPhase();
MS_LOG(INFO) << "Finish compiling.";
return true;
@@ -1027,7 +1028,7 @@ std::vector<ActionItem> GraphExecutorPy::FilterActions(const std::vector<ActionI
return filtered_actions;
}
void GraphExecutorPy::ReleaseResource(const py::object &phase) {
void GraphExecutorPy::ReleaseResourceOnException(const py::object &phase) {
bool clear = false;
// Be sure the pointer res is destroyed before DelOneNetRes is called.
{
@@ -1043,100 +1044,33 @@ void GraphExecutorPy::ReleaseResource(const py::object &phase) {
}
}
bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs,
bool GraphExecutorPy::Compile(const py::object &source, const py::tuple &args, const py::dict &kwargs,
const py::object &phase, bool use_vm) {
bool ret_value = false;
try {
ProcessStatus::GetInstance().RecordStart("CompileInner");
ret_value = CompileInner(source_obj, args, kwargs, phase, use_vm);
ProcessStatus::GetInstance().RecordEnd();
ProcessStatus::GetInstance().Print();
} catch (const py::error_already_set &ex) {
if (!StaticAnalysisException::Instance().HasException()) {
// print function call stack info before release
std::string compile_exception_info = GetCompileExceptionInfo();
if (!compile_exception_info.empty()) {
MS_LOG(ERROR) << compile_exception_info;
bool res = false;
HandleExceptionRethrow(
[this, &res, &source, &args, &kwargs, &phase, use_vm]() {
if (executor_running_) {
MS_LOG(EXCEPTION) << "Nested execution during JIT execution is not supported."
<< "\n\touter phase: " << phase_ << "\n\touter source: " << source_
<< "\n\tinner phase: " << py::cast<std::string>(phase) << "\n\tinner source: " << source;
}
}
ReleaseResource(phase);
// re-throw this exception to Python interpreter to handle it
throw(py::error_already_set(ex));
} catch (const py::type_error &ex) {
ReleaseResource(phase);
throw py::type_error(ex);
} catch (const py::value_error &ex) {
ReleaseResource(phase);
throw py::value_error(ex);
} catch (const py::index_error &ex) {
ReleaseResource(phase);
throw py::index_error(ex);
} catch (const py::key_error &ex) {
ReleaseResource(phase);
throw py::key_error(ex);
} catch (const py::attribute_error &ex) {
ReleaseResource(phase);
throw py::attribute_error(ex);
} catch (const py::name_error &ex) {
ReleaseResource(phase);
throw py::name_error(ex);
} catch (const py::assertion_error &ex) {
ReleaseResource(phase);
throw py::assertion_error(ex);
} catch (const py::base_exception &ex) {
ReleaseResource(phase);
throw py::base_exception(ex);
} catch (const py::keyboard_interrupt &ex) {
ReleaseResource(phase);
throw py::keyboard_interrupt(ex);
} catch (const py::stop_iteration &ex) {
ReleaseResource(phase);
throw py::stop_iteration(ex);
} catch (const py::overflow_error &ex) {
ReleaseResource(phase);
throw py::overflow_error(ex);
} catch (const py::zero_division_error &ex) {
ReleaseResource(phase);
throw py::zero_division_error(ex);
} catch (const py::environment_error &ex) {
ReleaseResource(phase);
throw py::environment_error(ex);
} catch (const py::io_error &ex) {
ReleaseResource(phase);
throw py::io_error(ex);
} catch (const py::os_error &ex) {
ReleaseResource(phase);
throw py::os_error(ex);
} catch (const py::memory_error &ex) {
ReleaseResource(phase);
throw py::memory_error(ex);
} catch (const py::unbound_local_error &ex) {
ReleaseResource(phase);
throw py::unbound_local_error(ex);
} catch (const py::not_implemented_error &ex) {
ReleaseResource(phase);
throw py::not_implemented_error(ex);
} catch (const py::indentation_error &ex) {
ReleaseResource(phase);
throw py::indentation_error(ex);
} catch (const py::runtime_warning &ex) {
ReleaseResource(phase);
throw py::runtime_warning(ex);
} catch (const std::exception &ex) {
ReleaseResource(phase);
// re-throw this exception to Python interpreter to handle it
throw(std::runtime_error(ex.what()));
} catch (...) {
ReleaseResource(phase);
#ifndef _MSC_VER
std::string exName(abi::__cxa_current_exception_type()->name());
MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
#else
MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: ";
#endif
}
return ret_value;
ProcessStatus::GetInstance().RecordStart("CompileInner");
res = CompileInner(source, args, kwargs, phase, use_vm);
ProcessStatus::GetInstance().RecordEnd();
ProcessStatus::GetInstance().Print();
},
[this, &phase]() {
if (!StaticAnalysisException::Instance().HasException()) {
// print function call stack info before release
std::string compile_exception_info = GetCompileExceptionInfo();
if (!compile_exception_info.empty()) {
MS_LOG(ERROR) << compile_exception_info;
}
}
ReleaseResourceOnException(phase);
},
[this, &phase]() { ReleaseResourceOnException(phase); }, [this, &phase]() { ReleaseResourceOnException(phase); });
return res;
}
void CacheValidateFuncGraph(const ResourcePtr &resource) {
@@ -1395,8 +1329,21 @@ std::pair<py::object, bool> GraphExecutorPy::GetPyExecuteOutputFromAddress(const
return {py::none(), false};
}
py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_obj) {
// init for dynamic-obfuscated model infer
py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase) {
py::object res;
HandleExceptionRethrow(
[this, &res, &args, &phase]() {
executor_running_ = true;
res = RunInner(args, phase);
executor_running_ = false;
},
[this]() { executor_running_ = false; }, [this]() { executor_running_ = false; },
[this]() { executor_running_ = false; });
return res;
}
py::object GraphExecutorPy::RunInner(const py::tuple &args, const py::object &phase_obj) {
// Init for dynamic-obfuscated model infer
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
// MindSpore debugger notifies the main thread to exit after one step, and will not run the next step
#ifdef ENABLE_DEBUGGER

View File

@@ -74,20 +74,16 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
~GraphExecutorPy();
bool Compile(const py::object &source, const py::tuple &args, const py::dict &kwargs, const py::object &phase,
bool use_vm);
py::object Run(const py::tuple &args, const py::object &phase);
const std::string &phase() const { return phase_; }
const std::map<std::string, std::string> &jit_config() const { return jit_config_; }
void SaveCompiledGraph(const std::string &phase);
void ConvertArgs(const py::tuple &args, const py::dict &kwargs, bool is_auto_parallel,
abstract::AbstractBasePtrList *args_abs, std::vector<ValuePtr> *arguments);
bool CompileInner(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs,
const py::object &phase_obj, bool use_vm);
bool Compile(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs, const py::object &phase,
bool use_vm);
void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list);
// for pynative mode when use_vm is on
py::object Run(const py::tuple &args, const py::object &phase_obj);
std::pair<py::object, bool> GetPyExecuteOutputFromAddress(const py::object &res, const BaseRef &value);
ResourcePtr GetResource(const std::string &phase);
FuncGraphPtr GetFuncGraph(const std::string &phase);
@@ -121,8 +117,8 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
size_t GetNumOpsInfo(const std::string &phase);
void SetNumOpsInfo(size_t num_ops);
py::dict GetAllreduceFusion(const std::string &phase);
void DelNetRes(const py::object &source_obj, const py::set &id);
void ReleaseResource(const py::object &phase);
void DelNetRes(const py::object &source, const py::set &id);
void ReleaseResourceOnException(const py::object &phase);
void CleanCompileRes(const ResourcePtr &resource);
static void ClearRes();
void set_queue_name(const std::string &queue_name) { queue_name_ = queue_name; }
@@ -153,11 +149,16 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
// If enable compile cache, get the compile cache resource.
void InitCompileCacheInfo(const ResourcePtr &resource, const std::string &phase);
bool CompileInner(const py::object &source, const py::tuple &args, const py::dict &kwargs, const py::object &phase,
bool use_vm);
py::object RunInner(const py::tuple &args, const py::object &phase);
std::map<std::string, ExecutorInfoPtr> info_;
static std::shared_ptr<GraphExecutorPy> executor_;
static std::mutex instance_lock_;
std::map<std::string, py::dict> stra_dict_;
std::string phase_ = "";
std::string phase_{""};
std::string source_{""};
std::map<std::string, std::string> jit_config_;
std::map<std::string, size_t> phase_to_num_op_info_;
std::string queue_name_;
@@ -166,14 +167,15 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
bool compile_cache_consistent_{true};
py::dict weights_;
std::map<PyObject *, std::pair<ValuePtr, AbstractBasePtr>> cur_convert_input_;
bool executor_running_{false};
};
using GraphExecutorPyPtr = std::shared_ptr<GraphExecutorPy>;
std::string GetJitLevel();
std::string GetObjDesc(const py::object &source_obj);
std::string GetObjDesc(const py::object &source);
bool IsPhaseLoadFromMindIR(const std::string &phase);
void CheckArgsValid(const py::object &source_obj, const py::tuple &args);
void CheckArgsValid(const py::object &source, const py::tuple &args);
py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple &inputs);
bool InitDistribute(const std::map<std::string, std::string> &options);

View File

@@ -1691,6 +1691,13 @@ EvalResultPtr GetEvaluatedValueForAdapterTensorAttrOrMethod(const AnalysisEngine
return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
}
bool IsPyExecuteCNodeData(const AbstractBasePtr &data_abstract) {
if (data_abstract->has_user_data("__py_execute_cnode_flag__")) {
return true;
}
return false;
}
void CheckObjAttrValid(const TypePtr &data_type, const std::string &item_name, const AbstractBasePtr &data_args) {
// Check if the obj's attr is invalid or decorated by @jit_forbidden_register
std::string data_type_str = TypeIdLabel(data_type->type_id());
@@ -1747,10 +1754,15 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
}
CheckObjAttrValid(data_type, item_name, data_args);
constexpr auto recursive_level = 3;
MS_LOG(DEBUG) << "Evaluate " << data_type->ToString() << " attribute: " << item_name
<< ".\nnode: " << out_conf->node()->DebugString() << "\n"
<< ".\nnode: " << out_conf->node()->DebugString(recursive_level) << "\n"
<< trace::GetDebugInfo(out_conf->node()->debug_info());
auto cnode = dyn_cast<CNode>(out_conf->node());
MS_EXCEPTION_IF_NULL(cnode);
if (!IsPyExecuteCNodeData(data_args)) { // Skip the check if the data is a PyExecute CNode.
CheckObjAttrValid(data_type, item_name, data_args);
}
auto res = InterpretGetAttrNode(args_abs_list, out_conf);
if (res == nullptr) {
MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
@@ -2067,6 +2079,8 @@ EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abst
shape = std::make_shared<Shape>(shp);
}
AbstractBasePtr res = std::make_shared<AbstractTensor>(type, shape);
// User data '__py_execute_cnode_flag__' is used by 'IsPyExecuteCNodeData' to identify a forward PyExecute CNode.
res->set_user_data("__py_execute_cnode_flag__", std::make_shared<bool>(true));
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;

View File

@@ -327,10 +327,12 @@ EvalResultPtr AnalysisEngine::InterpretedNodeCall(const CNodePtr &cnode, const A
// Check if the operator input is PyExecute CNode.
auto &func_node = inputs[0];
MS_EXCEPTION_IF_NULL(func_node);
constexpr auto debug_recursive_level = 2;
MS_LOG(DEBUG) << "Current CNode: " << cnode->DebugString(debug_recursive_level);
constexpr auto recursive_level = 2;
MS_LOG(DEBUG) << "Current CNode: " << cnode->DebugString(recursive_level)
<< ", func_node: " << func_node->DebugString(recursive_level);
auto prim = GetCNodePrimitiveWithoutDoSignature(func_node);
if (!IsPrimitiveEquals(prim, prim::kPrimGetAttr)) { // Optimize the performance.
if (!IsPrimitiveEquals(prim, prim::kPrimGetAttr) &&
!IsPrimitiveEquals(prim, prim::kPrimPyExecute)) { // Optimize the performance.
return nullptr;
}
AnfNodeConfigPtr func_conf = MakeConfig(func_node, conf->context(), conf->func_graph());

View File

@@ -29,6 +29,7 @@
#include "ir/cell.h"
#include "abstract/utils.h"
#include "include/common/utils/stub_tensor.h"
#include "include/common/utils/python_utils.h"
namespace mindspore::pynative {
std::shared_ptr<PyNativeExecutor> PyNativeExecutor::executor_ = nullptr;
@@ -43,46 +44,26 @@ T PyNativeExecutorTry(const std::function<T(const Args &...)> &method, const Arg
const auto &inst = PyNativeExecutor::GetInstance();
MS_EXCEPTION_IF_NULL(inst);
MS_EXCEPTION_IF_NULL(method);
try {
return method(args...);
} catch (const py::error_already_set &ex) {
// print function call stack info before release
auto already_set_error_handler = [&inst]() {
// Print function call stack info before release.
std::ostringstream oss;
trace::TraceGraphEval();
trace::GetEvalStackInfo(oss);
// call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
// these info from screen, no need to open log file to find these info
// Call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
// these info from screen, no need to open log file to find these info.
py::print(oss.str());
MS_LOG(ERROR) << oss.str();
inst->ClearRes();
// re-throw this exception to Python interpreter to handle it
throw(py::error_already_set(ex));
} catch (const py::index_error &ex) {
inst->ClearRes();
throw py::index_error(ex);
} catch (const py::value_error &ex) {
inst->ClearRes();
throw py::value_error(ex);
} catch (const py::type_error &ex) {
inst->ClearRes();
throw py::type_error(ex);
} catch (const py::name_error &ex) {
inst->ClearRes();
throw py::name_error(ex);
} catch (const std::exception &ex) {
inst->ClearRes();
// re-throw this exception to Python interpreter to handle it
throw(std::runtime_error(ex.what()));
} catch (...) {
inst->ClearRes();
#ifndef _MSC_VER
auto exception_type = abi::__cxa_current_exception_type();
MS_EXCEPTION_IF_NULL(exception_type);
std::string ex_name(exception_type->name());
MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << ex_name;
#else
MS_LOG(EXCEPTION) << "Error occurred when compile graph.";
#endif
};
if constexpr (std::is_same_v<T, void>) {
HandleExceptionRethrow([&method, &args...]() { method(args...); }, already_set_error_handler,
[&inst]() { inst->ClearRes(); }, [&inst]() { inst->ClearRes(); });
} else {
T res;
HandleExceptionRethrow([&res, &method, &args...]() { res = method(args...); }, already_set_error_handler,
[&inst]() { inst->ClearRes(); }, [&inst]() { inst->ClearRes(); });
return res;
}
}
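
The if constexpr split above exists because the wrapped method may return void; below is a small self-contained sketch of the same dispatch technique with generic names (editor's example, independent of MindSpore).
#include <functional>
#include <iostream>
#include <type_traits>
// Editor's sketch: run a callable under a common cleanup handler, supporting both
// void and non-void return types via if constexpr (the technique used by PyNativeExecutorTry).
template <typename T, typename... Args>
T TryWithCleanup(const std::function<T(const Args &...)> &method, const std::function<void()> &cleanup,
                 const Args &... args) {
  try {
    if constexpr (std::is_same_v<T, void>) {
      method(args...);
    } else {
      return method(args...);
    }
  } catch (...) {
    cleanup();  // Release resources before re-throwing to the caller.
    throw;
  }
}
int main() {
  std::function<int(const int &)> square = [](const int &x) { return x * x; };
  std::cout << TryWithCleanup<int, int>(square, [] { std::cerr << "cleanup\n"; }, 3) << std::endl;
  return 0;
}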

View File

@@ -29,6 +29,7 @@
#include "include/common/fallback.h"
#include "mindspore/core/ops/py_execute.h"
#include "mindspore/ccsrc/include/common/utils/convert_utils_py.h"
#include "mindspore/ccsrc/include/common/utils/python_utils.h"
#include "mindspore/ccsrc/include/common/utils/python_adapter.h"
#include "mindspore/ccsrc/include/common/utils/python_fallback_running.h"
#include "mindspore/ccsrc/pipeline/jit/parse/data_converter.h"
@@ -161,96 +162,20 @@ class PyExecuteInitializer {
const AbstractBasePtrList &args_spec_list) {
// We can't catch pybind11 exceptions via py::builtin_exception or its base class,
// so we have to list all pybind11 exceptions and catch them one by one here.
try {
const auto &abs = opt::CppInferShapeAndType(primitive, args_spec_list);
MS_LOG(DEBUG) << "The abstract of " << cnode->fullname_with_scope() << " changes from " << cnode->abstract()
<< " to " << abs;
return abs;
} catch (const py::type_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::type_error(ss.str());
} catch (const py::value_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::value_error(ss.str());
} catch (const py::index_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::index_error(ss.str());
} catch (const py::key_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::key_error(ss.str());
} catch (const py::attribute_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::attribute_error(ss.str());
} catch (const py::name_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::name_error(ss.str());
} catch (const py::assertion_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::assertion_error(ss.str());
} catch (const py::base_exception &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::base_exception(ss.str());
} catch (const py::keyboard_interrupt &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::keyboard_interrupt(ss.str());
} catch (const py::stop_iteration &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::stop_iteration(ss.str());
} catch (const py::overflow_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::overflow_error(ss.str());
} catch (const py::zero_division_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::zero_division_error(ss.str());
} catch (const py::environment_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::environment_error(ss.str());
} catch (const py::io_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::io_error(ss.str());
} catch (const py::os_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::os_error(ss.str());
} catch (const py::memory_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::memory_error(ss.str());
} catch (const py::unbound_local_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::unbound_local_error(ss.str());
} catch (const py::not_implemented_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::not_implemented_error(ss.str());
} catch (const py::indentation_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::indentation_error(ss.str());
} catch (const py::runtime_warning &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw py::runtime_warning(ss.str());
} catch (const std::runtime_error &e) {
std::stringstream ss;
ss << e.what() << ".\n\n" << trace::GetDebugInfo(cnode->debug_info());
throw std::runtime_error(ss.str());
}
AbstractBasePtr res;
std::function<void(void)> already_set_error_handler;
std::function<void(void)> other_error_handler;
std::function<void(void)> default_error_handler;
HandleExceptionRethrow(
[&res, &cnode, &primitive, &args_spec_list]() {
res = opt::CppInferShapeAndType(primitive, args_spec_list);
MS_LOG(DEBUG) << "The abstract of " << cnode->fullname_with_scope() << " changes from " << cnode->abstract()
<< " to " << res;
return res;
},
already_set_error_handler, other_error_handler, default_error_handler,
cnode->debug_info()); // Use debug_info to re-throw.
return res;
}
};

View File

@@ -0,0 +1,307 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/common/utils/python_utils.h"
#include "include/common/utils/python_adapter.h"
#include "pybind_api/pybind_patch.h"
namespace mindspore {
void HandleExceptionRethrow(const std::function<void(void)> &main_func,
const std::function<void(void)> &already_set_error_handler,
const std::function<void(void)> &other_error_handler,
const std::function<void(void)> &default_error_handler, const DebugInfoPtr &debug_info) {
try {
if (!main_func) {
MS_LOG(ERROR) << "The 'main_func' should not be empty.";
return;
}
main_func();
} catch (const py::error_already_set &ex) {
if (already_set_error_handler) {
already_set_error_handler();
}
// Re-throw this exception to Python interpreter to handle it
throw(py::error_already_set(ex));
} catch (const py::type_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::type_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::type_error(ss.str());
}
} catch (const py::value_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::value_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::value_error(ss.str());
}
} catch (const py::index_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::index_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::index_error(ss.str());
}
} catch (const py::key_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::key_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::key_error(ss.str());
}
} catch (const py::attribute_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::attribute_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::attribute_error(ss.str());
}
} catch (const py::name_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::name_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::name_error(ss.str());
}
} catch (const py::assertion_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::assertion_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::assertion_error(ss.str());
}
} catch (const py::base_exception &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::base_exception(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::base_exception(ss.str());
}
} catch (const py::keyboard_interrupt &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::keyboard_interrupt(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::keyboard_interrupt(ss.str());
}
} catch (const py::stop_iteration &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::stop_iteration(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::stop_iteration(ss.str());
}
} catch (const py::overflow_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::overflow_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::overflow_error(ss.str());
}
} catch (const py::zero_division_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::zero_division_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::zero_division_error(ss.str());
}
} catch (const py::environment_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::environment_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::environment_error(ss.str());
}
} catch (const py::io_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::io_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::io_error(ss.str());
}
} catch (const py::os_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::os_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::os_error(ss.str());
}
} catch (const py::memory_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::memory_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::memory_error(ss.str());
}
} catch (const py::unbound_local_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::unbound_local_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::unbound_local_error(ss.str());
}
} catch (const py::not_implemented_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::not_implemented_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::not_implemented_error(ss.str());
}
} catch (const py::indentation_error &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::indentation_error(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::indentation_error(ss.str());
}
} catch (const py::runtime_warning &ex) {
if (other_error_handler) {
other_error_handler();
}
if (debug_info == nullptr) {
throw py::runtime_warning(ex);
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw py::runtime_warning(ss.str());
}
} catch (const std::exception &ex) {
if (other_error_handler) {
other_error_handler();
}
// Re-throw this exception to Python interpreter to handle it.
if (debug_info == nullptr) {
throw std::runtime_error(ex.what());
} else {
std::stringstream ss;
ss << ex.what() << ".\n\n" << trace::GetDebugInfo(debug_info);
throw std::runtime_error(ss.str());
}
} catch (...) {
if (default_error_handler) {
default_error_handler();
}
#ifndef _MSC_VER
auto exception_type = abi::__cxa_current_exception_type();
MS_EXCEPTION_IF_NULL(exception_type);
std::string ex_name(exception_type->name());
MS_LOG(EXCEPTION) << "Error occurred. Exception name: " << ex_name;
#else
MS_LOG(EXCEPTION) << "Error occurred.";
#endif
}
}
} // namespace mindspore

View File

@@ -15,8 +15,10 @@
*/
#include "include/common/utils/utils.h"
#include <set>
#include <string>
namespace mindspore {
bool IsOneOfPosteriorOperator(const std::string &name) {
const std::set<std::string> kPosteriorOperatorSet = {kPullOpName};

View File

@@ -815,7 +815,7 @@ def eval_script(exp_str, params):
except Exception as e:
error_info = f"When eval '{exp_str}' by using JIT Fallback feature, an error occurred: " + str(e) + \
". You can try to turn off JIT Fallback feature by 'export MS_DEV_ENABLE_FALLBACK=0'."
logger.error(error_info)
logger.debug(error_info)
raise e
# Convert set to tuple.

View File

@@ -873,3 +873,67 @@ def test_parser_fallback_nested_class_outer_grad():
net = NestedNet()
output = ops.grad(net)(mutable(x), y)
assert output == 0
class UserDefinedNet:
def __init__(self):
self.value = 10
def __call__(self, x):
return self.value * x
class UserDefinedMsFunctionCallNet:
def __init__(self):
self.value = 10
@ms.jit
def __call__(self, x):
return self.value * x
class UNet(ms.nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
def construct(self, x):
out = x * x
out = self.net(x)
out = out + out
return out
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resolve_cust_class():
"""
Feature: Syntax resolve.
Description: Graph syntax resolve supports custom class input.
Expectation: No error.
"""
net = UNet(UserDefinedNet())
x = np.array([10], np.float32)
output = net(ms.Tensor(x))
print(output) # The output should be 200, but it currently fails; check later.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resolve_cust_ms_function_call_class():
"""
Feature: Syntax resolve.
Description: Graph syntax resolve supports custom class input.
Expectation: No error.
"""
net = UNet(UserDefinedMsFunctionCallNet())
x = np.array([10, 10], np.float32)
with pytest.raises(RuntimeError) as err:
net(ms.Tensor(x))
assert "Nested execution during JIT execution is not supported." in str(err.value)