From 9c92998a796b34a45021f283487a08c76f26ed90 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Tue, 7 Mar 2023 17:18:23 +0800 Subject: [PATCH] Support to return isolated node in control flow branches with raise node. --- .../frontend/optimizer/fallback_rewriter.cc | 27 ++ .../pipeline/jit/static_analysis/evaluator.cc | 7 + .../pipeline/jit/static_analysis/prim.cc | 24 +- .../jit/static_analysis/stack_frame.cc | 9 +- .../jit/static_analysis/static_analysis.cc | 410 +++++++++--------- .../jit/static_analysis/static_analysis.h | 9 +- mindspore/ccsrc/pipeline/jit/validator.cc | 25 +- tests/st/fallback/test_graph_fallback_none.py | 67 ++- .../test_graph_raise_with_variable.py | 30 +- 9 files changed, 349 insertions(+), 259 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/fallback_rewriter.cc b/mindspore/ccsrc/frontend/optimizer/fallback_rewriter.cc index 172a41f14d7..e8feb0df0c5 100644 --- a/mindspore/ccsrc/frontend/optimizer/fallback_rewriter.cc +++ b/mindspore/ccsrc/frontend/optimizer/fallback_rewriter.cc @@ -612,6 +612,25 @@ class CleanAfterOptARewriter : public BaseRewriter { : BaseRewriter(root_graph, manager) {} ~CleanAfterOptARewriter() override = default; + void UpdateAbstracts() { + const auto &nodes = manager_->all_nodes(); + for (const auto &node : nodes) { + const auto &abs = node->abstract(); + if (abs == nullptr) { + continue; + } + // Set flag for convert AbstractNone(PyExecute) to AbstractTensor in next renormalize. + if (IsPrimitiveCNode(node, prim::kPrimPyExecute) && abs->isa()) { + constexpr auto data_type = "__py_execute_no_return_type__"; + if (node->has_user_data(data_type)) { + auto type = std::make_shared(); + node->set_user_data(data_type, type); + set_need_renormalized(true); + } + } + } + } + protected: // From: // MakeSparseTensor(indices, values, dense_shape) @@ -897,6 +916,13 @@ class CleanAfterOptARewriter : public BaseRewriter { AnfNodePtr none_execute_node = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimPyExecute), script_node, none_tuple_node, none_tuple_node}); MS_LOG(DEBUG) << "none_execute_node:" << none_execute_node->DebugString(); + + // Keep AbstractNone for PyExecute, because the control flow join problem. + auto none_type = std::make_shared(); + none_execute_node->set_user_data("__py_execute_no_return_type__", none_type); + AbstractBasePtr res = std::make_shared(); + res->set_value(kAnyValue); + none_execute_node->set_abstract(res); return none_execute_node; } @@ -1051,6 +1077,7 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const pipeline::ResourcePtr &resou CleanAfterOptARewriter rewriter(root, manager); bool change = rewriter.Execute(); // Renormalize for new PyExecute node. + rewriter.UpdateAbstracts(); if (rewriter.need_renormalized()) { abstract::AbstractBasePtrList new_args_spec; std::transform(root->parameters().begin(), root->parameters().end(), std::back_inserter(new_args_spec), diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index 0c280a1084f..c1ea31da7cf 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -225,6 +225,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0"); if (enable_eliminate_unused_element) { const auto &cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); const auto &maybe_func = engine->GetCNodeOperatorAbstract(cnode, context, fg); if (maybe_func->isa() || maybe_func->isa()) { @@ -232,6 +233,12 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine SynchronizeSequenceElementsUseFlagsForFuncGraphArgs(engine, fg, cnode, abs_func_graph, context); } } + if (engine->check_isolated_side_effect() && node_eval_result->has_isolated_side_effect()) { + const auto &cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + cnode->set_has_isolated_side_effect_node(true); + fg->set_has_isolated_side_effect_node(true); + } MS_LOG(DEBUG) << "No need to jump as found result from cache for node_config"; } else { node_eval_result = engine->ObtainEvalResultWithoutCache(node_conf); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 875be7c4d64..509954af376 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -2064,12 +2064,16 @@ EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abst MS_LOG(DEBUG) << "Call script: " << script << ", args: " << args_abs_list; // when return value should be none - if (current_interpret_node->has_user_data("__py_execute_no_return_type__")) { - AbstractBasePtr res = std::make_shared(); - res->set_value(kAnyValue); - auto infer_result = std::make_shared(res, std::make_shared()); - evaluator_cache_mgr_->SetValue(args_abs_list, infer_result); - return infer_result; + constexpr auto data_type = "__py_execute_no_return_type__"; + if (current_interpret_node->has_user_data(data_type)) { + auto type = current_interpret_node->user_data(data_type); + if (type->isa()) { + AbstractBasePtr res = std::make_shared(); + res->set_value(kAnyValue); + auto infer_result = std::make_shared(res, std::make_shared()); + evaluator_cache_mgr_->SetValue(args_abs_list, infer_result); + return infer_result; + } } TypePtr type = kFloat64; if (current_interpret_node->has_user_data("__py_execute_tensor_type__")) { @@ -2811,6 +2815,14 @@ class RaiseEvaluator : public TransitionPrimEvaluator { auto none_type = std::make_shared(); raise_error_node->set_user_data("__py_execute_no_return_type__", none_type); cur_graph->ReplaceInOrder(node, raise_error_node); + + // Set isolated side effect flag for raise node. + const auto &manager = cur_graph->manager(); + manager->Replace(node, raise_error_node); + raise_error_node->set_has_isolated_side_effect_node(true); + cur_graph->set_has_isolated_side_effect_node(true); + MS_LOG(DEBUG) << "Found Side Effect Primitive CNode: " << raise_error_node->DebugString(); + AnalysisEnginePtr eng = out_conf->engine(); MS_EXCEPTION_IF_NULL(eng); AnfNodeConfigPtr fn_conf = eng->MakeConfig(raise_error_node, out_conf->context(), out_conf->func_graph()); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc index 8b80b71a528..079264a8f2a 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021-2022 Huawei Technologies Co., Ltd + * Copyright 2021-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -199,6 +199,13 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) { node_eval_result = engine->ObtainEvalResultWithoutCache(node_conf); } else { node_eval_result = engine->ObtainEvalResultWithCache(node_conf); + MS_EXCEPTION_IF_NULL(node_eval_result); + if (engine->check_isolated_side_effect() && node_eval_result->has_isolated_side_effect()) { + const auto &cnode = current_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + cnode->set_has_isolated_side_effect_node(true); + current_context_->func_graph()->set_has_isolated_side_effect_node(true); + } } MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString() << ") = " << (node_eval_result->abstract() ? node_eval_result->abstract()->ToString() : "Abstract null"); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 82f2ed91ba7..79fe95034cd 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -64,6 +64,208 @@ void DecreaseStackFrameDepth() { } size_t StackFrameDepth() { return stack_frame_depth; } +namespace { +bool NeedWaitForBranches(const AbstractBasePtr &abstract) { + MS_EXCEPTION_IF_NULL(abstract); + if (abstract->isa()) { + return true; + } + if (abstract->isa()) { + auto elements = abstract->cast_ptr()->elements(); + if (std::any_of(elements.begin(), elements.end(), + [](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) { + return true; + } + } + return false; +} + +void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf, + std::string thread_id, AsyncAbstractPtr async_result_branch, AsyncAbstractPtr async_result_main, + AsyncInferTaskPtr async_task, trace::TraceGraphEvalStack graph_evals, + trace::TraceCNodeEvalStack trace_c_node_evals) { + AnalysisSchedule::set_thread_id(thread_id); + // Restore trace stack for dump stack when there is exception. + trace::TraceEvalCNodeStackPrepare(trace_c_node_evals); + trace_c_node_evals.clear(); + trace::TraceGraphEvalStackPrepare(graph_evals); + graph_evals.clear(); + + try { + // Wait for Signal to run + MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " waiting."; + (void)async_task->GetResult(); + MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " running."; + + // Acquire GIL for eval to callback python. + EvalResultPtr result; + { + MS_LOG(DEBUG) << std::this_thread::get_id() << " begin."; + py::gil_scoped_acquire py_guard; + result = eval->Run(engine, args_conf_list, out_conf); + } + MS_LOG(DEBUG) << std::this_thread::get_id() << " end."; + MS_EXCEPTION_IF_NULL(result); + MS_EXCEPTION_IF_NULL(result->abstract()); + + // Check the branch value to be compatible with the other branch value. + AnalysisResultCacheMgr::GetInstance().CheckSwitchValueJoinable(out_conf, result->abstract()); + // Broaden the result of switch(c,t,f)() + auto broaden_abstract = result->abstract()->Broaden(); + // Notify the thread of waiting for branch value and the main thread to continue. + async_result_branch->set_result(broaden_abstract); + async_result_main->set_result(broaden_abstract); + MS_LOG(DEBUG) << GetInferThread() << " async :" << eval->ToString() + << " asyncResult address = " << async_result_branch.get() + << " value = " << async_result_branch->TryGetResult()->ToString(); + } catch (const std::exception &ex) { + MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception."; + AnalysisSchedule::GetInstance().HandleException(ex); + } + trace::ClearTraceStack(); + ClearThreadLocal(); + MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " exited."; + // Thread number will be drop when thread exits. + AnalysisSchedule::GetInstance().DecreaseThreadCount(); +} + +AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs, + const std::vector &pending_async_abstract_list, + const std::vector &index) { + auto sequence_abs = dyn_cast_ptr(orig_abs); + if (sequence_abs != nullptr) { + const auto &orig_elements = sequence_abs->elements(); + AbstractBasePtrList new_elements; + for (size_t i = 0; i < orig_elements.size(); ++i) { + if (orig_elements[i]->isa()) { + AbstractFuncAtomPtrList abs_func_list{orig_elements[i]->cast()}; + for (size_t j = 0; j < pending_async_abstract_list.size(); ++j) { + std::vector new_index(index); + new_index.push_back(i); + auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], new_index); + abs_func_list.push_back(async_func); + } + new_elements.push_back(AbstractFunction::MakeAbstractFunction(abs_func_list)); + } else if (orig_elements[i]->isa()) { + std::vector new_index(index); + new_index.push_back(i); + new_elements.push_back(BuildAsyncAbstractRecursively(orig_elements[i], pending_async_abstract_list, new_index)); + } else { + new_elements.push_back(orig_elements[i]); + } + } + static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0"); + AbstractBasePtr new_abs; + if (orig_abs->isa()) { + new_abs = std::make_shared( + new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr)); + } else if (orig_abs->isa()) { + new_abs = std::make_shared( + new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr)); + } else { + MS_LOG(EXCEPTION) << "FirstResult is not AbstractTuple or AbstractList, but: " << orig_abs->ToString(); + } + return new_abs; + } + MS_LOG(EXCEPTION) << "Orig abstract is not AbstractTuple or AbstractList, but: " << orig_abs->ToString(); +} + +void BuildPossibleSpecs(const AbstractBasePtr &first_result, + const std::vector &branch_async_abstract_list, + AbstractBasePtrList *out_specs) { + MS_EXCEPTION_IF_NULL(out_specs); + MS_EXCEPTION_IF_NULL(first_result); + std::vector pending_async_abstract_list; + std::size_t len = branch_async_abstract_list.size(); + + for (size_t i = 0; i < len; ++i) { + auto result = branch_async_abstract_list[i]->TryGetResult(); + if (result) { + out_specs->push_back(result); + } else { + pending_async_abstract_list.push_back(branch_async_abstract_list[i]); + } + } + if (first_result->isa()) { + for (std::size_t j = 0; j < pending_async_abstract_list.size(); ++j) { + auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], std::vector{0}); + out_specs->push_back(async_func); + } + } else if (first_result->isa()) { + const auto &new_first_result = + BuildAsyncAbstractRecursively(first_result, pending_async_abstract_list, std::vector()); + MS_LOG(DEBUG) << GetInferThread() << " Try to replace old first with new one, old: " << first_result->ToString() + << ", new: " << new_first_result->ToString(); + std::replace_if( + out_specs->begin(), out_specs->end(), [first_result](const auto &element) { return element == first_result; }, + new_first_result); + } else { + MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result"; + } +} + +void CheckInterpretedObject(const AbstractBasePtr &abs) { + static const auto support_fallback = common::GetEnv("MS_DEV_ENABLE_FALLBACK"); + static const auto use_fallback = (support_fallback != "0"); + if (!use_fallback) { + return; + } + auto value = abs->BuildValue(); + if (value->isa()) { + MS_LOG(ERROR) << "Do not support " << value->ToString() << ". " + << "\nIf you are using third-party modules, you can try setting: " + << "'export MS_DEV_SUPPORT_MODULES=module1,module2,...'."; + } +} + +EvalResultPtr ConvertClassToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs, const AnfNodeConfigPtr &conf) { + auto val = abs->BuildValue(); + auto class_val = dyn_cast_ptr(val); + const auto &class_name = class_val->name(); + py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); + auto py_fn = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CONVERT_CLASS_TO_FUNCTION, py::str(class_name)); + if (py::isinstance(py_fn)) { + MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << abs->ToString() << "."; + MS_LOG(ERROR) << "It's called at: " << cnode->DebugString(); + MS_EXCEPTION(ValueError) << "Can not call " << class_name << " to create python object in graph mode. " + << "Try using 'jit_class' to decorate the class?"; + } + auto list_func_fg = parse::ParsePythonCode(py_fn); + auto fg = cnode->func_graph(); + list_func_fg->set_manager(fg->manager()); + + auto &inputs = cnode->inputs(); + std::vector new_cnode_inputs; + (void)new_cnode_inputs.emplace_back(NewValueNode(list_func_fg)); + for (std::size_t i = 1; i < inputs.size(); ++i) { + (void)new_cnode_inputs.emplace_back(inputs[i]); + } + auto new_cnode = fg->NewCNodeInOrder(new_cnode_inputs); + fg->ReplaceInOrder(cnode, new_cnode); + + AnalysisEnginePtr eng = conf->engine(); + MS_EXCEPTION_IF_NULL(eng); + AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph()); + return eng->ForwardConfig(conf, fn_conf); +} + +bool CheckFuncIsolatedSideEffect(const AbstractFunctionPtr &func, bool check_isolated_side_effect) { + // Check if func graph contains isolated side-effect, and sync. + if (check_isolated_side_effect) { + auto func_graph_abs = dyn_cast_ptr(func); + if (func_graph_abs != nullptr) { + return func_graph_abs->func_graph()->has_isolated_side_effect_node(); + } else { + auto meta_func_graph_abs = dyn_cast_ptr(func); + if (meta_func_graph_abs != nullptr) { + return meta_func_graph_abs->meta_func_graph()->has_isolated_side_effect_node(); + } + } + } + return false; +} +} // namespace + EvalResultPtr PrimitiveEvalCache::Get(const PrimitivePtr &prim, const AbstractBasePtrList &args) const { MS_EXCEPTION_IF_NULL(prim); std::lock_guard guard(mutex_); @@ -407,51 +609,6 @@ AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, return possible_func; } -void CheckInterpretedObject(const AbstractBasePtr &abs) { - static const auto support_fallback = common::GetEnv("MS_DEV_ENABLE_FALLBACK"); - static const auto use_fallback = (support_fallback != "0"); - if (!use_fallback) { - return; - } - auto value = abs->BuildValue(); - if (value->isa()) { - MS_LOG(ERROR) << "Do not support " << value->ToString() << ". " - << "\nIf you are using third-party modules, you can try setting: " - << "'export MS_DEV_SUPPORT_MODULES=module1,module2,...'."; - } -} - -EvalResultPtr ConvertClassToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs, const AnfNodeConfigPtr &conf) { - auto val = abs->BuildValue(); - auto class_val = dyn_cast_ptr(val); - const auto &class_name = class_val->name(); - py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); - auto py_fn = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CONVERT_CLASS_TO_FUNCTION, py::str(class_name)); - if (py::isinstance(py_fn)) { - MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << abs->ToString() << "."; - MS_LOG(ERROR) << "It's called at: " << cnode->DebugString(); - MS_EXCEPTION(ValueError) << "Can not call " << class_name << " to create python object in graph mode. " - << "Try using 'jit_class' to decorate the class?"; - } - auto list_func_fg = parse::ParsePythonCode(py_fn); - auto fg = cnode->func_graph(); - list_func_fg->set_manager(fg->manager()); - - auto &inputs = cnode->inputs(); - std::vector new_cnode_inputs; - (void)new_cnode_inputs.emplace_back(NewValueNode(list_func_fg)); - for (std::size_t i = 1; i < inputs.size(); ++i) { - (void)new_cnode_inputs.emplace_back(inputs[i]); - } - auto new_cnode = fg->NewCNodeInOrder(new_cnode_inputs); - fg->ReplaceInOrder(cnode, new_cnode); - - AnalysisEnginePtr eng = conf->engine(); - MS_EXCEPTION_IF_NULL(eng); - AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph()); - return eng->ForwardConfig(conf, fn_conf); -} - EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(cnode); @@ -524,19 +681,21 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list); // Check if func graph contains isolated side-effect, and sync. if (check_isolated_side_effect()) { - auto func_graph_abs = mindspore::cast(func); - if (func_graph_abs != nullptr) { - contains_isolated_side_effect = - contains_isolated_side_effect || func_graph_abs->func_graph()->has_isolated_side_effect_node(); - } - auto meta_func_graph_abs = mindspore::cast(func); - if (meta_func_graph_abs != nullptr) { - contains_isolated_side_effect = - contains_isolated_side_effect || meta_func_graph_abs->meta_func_graph()->has_isolated_side_effect_node(); - } + func->Visit([this, &contains_isolated_side_effect](const AbstractFuncAtomPtr &poss) { + MS_EXCEPTION_IF_NULL(poss); + auto resolved_atom = poss; + auto async_abs_func = poss->cast_ptr(); + if (async_abs_func != nullptr) { + auto resolved_func = async_abs_func->GetUnique(); + resolved_atom = dyn_cast(resolved_func); + MS_EXCEPTION_IF_NULL(resolved_atom); + } + contains_isolated_side_effect |= CheckFuncIsolatedSideEffect(resolved_atom, check_isolated_side_effect()); + }); if (contains_isolated_side_effect) { cnode->set_has_isolated_side_effect_node(true); conf->func_graph()->set_has_isolated_side_effect_node(true); + eval_result->set_has_isolated_side_effect(true); } } return eval_result; @@ -996,147 +1155,6 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_ return std::make_shared(joined_spec, std::make_shared()); } -namespace { -bool NeedWaitForBranches(const AbstractBasePtr &abstract) { - MS_EXCEPTION_IF_NULL(abstract); - if (abstract->isa()) { - return true; - } - if (abstract->isa()) { - auto elements = abstract->cast_ptr()->elements(); - if (std::any_of(elements.begin(), elements.end(), - [](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) { - return true; - } - } - return false; -} - -void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf, - std::string thread_id, AsyncAbstractPtr async_result_branch, AsyncAbstractPtr async_result_main, - AsyncInferTaskPtr async_task, trace::TraceGraphEvalStack graph_evals, - trace::TraceCNodeEvalStack trace_c_node_evals) { - AnalysisSchedule::set_thread_id(thread_id); - // Restore trace stack for dump stack when there is exception. - trace::TraceEvalCNodeStackPrepare(trace_c_node_evals); - trace_c_node_evals.clear(); - trace::TraceGraphEvalStackPrepare(graph_evals); - graph_evals.clear(); - - try { - // Wait for Signal to run - MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " waiting."; - (void)async_task->GetResult(); - MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " running."; - - // Acquire GIL for eval to callback python. - EvalResultPtr result; - { - MS_LOG(DEBUG) << std::this_thread::get_id() << " begin."; - py::gil_scoped_acquire py_guard; - result = eval->Run(engine, args_conf_list, out_conf); - } - MS_LOG(DEBUG) << std::this_thread::get_id() << " end."; - MS_EXCEPTION_IF_NULL(result); - MS_EXCEPTION_IF_NULL(result->abstract()); - - // Check the branch value to be compatible with the other branch value. - AnalysisResultCacheMgr::GetInstance().CheckSwitchValueJoinable(out_conf, result->abstract()); - // Broaden the result of switch(c,t,f)() - auto broaden_abstract = result->abstract()->Broaden(); - // Notify the thread of waiting for branch value and the main thread to continue. - async_result_branch->set_result(broaden_abstract); - async_result_main->set_result(broaden_abstract); - MS_LOG(DEBUG) << GetInferThread() << " async :" << eval->ToString() - << " asyncResult address = " << async_result_branch.get() - << " value = " << async_result_branch->TryGetResult()->ToString(); - } catch (const std::exception &ex) { - MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception."; - AnalysisSchedule::GetInstance().HandleException(ex); - } - trace::ClearTraceStack(); - ClearThreadLocal(); - MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " exited."; - // Thread number will be drop when thread exits. - AnalysisSchedule::GetInstance().DecreaseThreadCount(); -} - -AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs, - const std::vector &pending_async_abstract_list, - const std::vector &index) { - auto sequence_abs = dyn_cast_ptr(orig_abs); - if (sequence_abs != nullptr) { - const auto &orig_elements = sequence_abs->elements(); - AbstractBasePtrList new_elements; - for (size_t i = 0; i < orig_elements.size(); ++i) { - if (orig_elements[i]->isa()) { - AbstractFuncAtomPtrList abs_func_list{orig_elements[i]->cast()}; - for (size_t j = 0; j < pending_async_abstract_list.size(); ++j) { - std::vector new_index(index); - new_index.push_back(i); - auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], new_index); - abs_func_list.push_back(async_func); - } - new_elements.push_back(AbstractFunction::MakeAbstractFunction(abs_func_list)); - } else if (orig_elements[i]->isa()) { - std::vector new_index(index); - new_index.push_back(i); - new_elements.push_back(BuildAsyncAbstractRecursively(orig_elements[i], pending_async_abstract_list, new_index)); - } else { - new_elements.push_back(orig_elements[i]); - } - } - static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0"); - AbstractBasePtr new_abs; - if (orig_abs->isa()) { - new_abs = std::make_shared( - new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr)); - } else if (orig_abs->isa()) { - new_abs = std::make_shared( - new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr)); - } else { - MS_LOG(EXCEPTION) << "FirstResult is not AbstractTuple or AbstractList, but: " << orig_abs->ToString(); - } - return new_abs; - } - MS_LOG(EXCEPTION) << "Orig abstract is not AbstractTuple or AbstractList, but: " << orig_abs->ToString(); -} - -void BuildPossibleSpecs(const AbstractBasePtr &first_result, - const std::vector &branch_async_abstract_list, - AbstractBasePtrList *out_specs) { - MS_EXCEPTION_IF_NULL(out_specs); - MS_EXCEPTION_IF_NULL(first_result); - std::vector pending_async_abstract_list; - std::size_t len = branch_async_abstract_list.size(); - - for (size_t i = 0; i < len; ++i) { - auto result = branch_async_abstract_list[i]->TryGetResult(); - if (result) { - out_specs->push_back(result); - } else { - pending_async_abstract_list.push_back(branch_async_abstract_list[i]); - } - } - if (first_result->isa()) { - for (std::size_t j = 0; j < pending_async_abstract_list.size(); ++j) { - auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], std::vector{0}); - out_specs->push_back(async_func); - } - } else if (first_result->isa()) { - const auto &new_first_result = - BuildAsyncAbstractRecursively(first_result, pending_async_abstract_list, std::vector()); - MS_LOG(DEBUG) << GetInferThread() << " Try to replace old first with new one, old: " << first_result->ToString() - << ", new: " << new_first_result->ToString(); - std::replace_if( - out_specs->begin(), out_specs->end(), [first_result](const auto &element) { return element == first_result; }, - new_first_result); - } else { - MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result"; - } -} -} // namespace - EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index 9576bf5731e..ce61cb23be9 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -59,16 +59,23 @@ using AttrValueMapPtr = std::shared_ptr; // the class to save evaluated result: abstract value and modified attribute class EvalResult : public Base { public: - EvalResult(const AbstractBasePtr &abs, const AttrValueMapPtr &attr) : abstract_(abs), attribute_(attr) {} + EvalResult(const AbstractBasePtr &abs, const AttrValueMapPtr &attr) + : abstract_(abs), attribute_(attr), has_isolated_side_effect_(false) {} ~EvalResult() override = default; MS_DECLARE_PARENT(EvalResult, Base); const AbstractBasePtr &abstract() const { return abstract_; } const AttrValueMapPtr &attribute() const { return attribute_; } + bool has_isolated_side_effect() const { return has_isolated_side_effect_; } + void set_has_isolated_side_effect(bool has_isolated_side_effect) { + has_isolated_side_effect_ = has_isolated_side_effect; + } private: AbstractBasePtr abstract_; // Attribute related to PrimEvaluator; AttrValueMapPtr attribute_; + + bool has_isolated_side_effect_; }; using EvalResultPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc index 41cfe1fb4b1..d009b75959b 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.cc +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -1,7 +1,7 @@ /** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * - * Copyright 2019-2022 Huawei Technologies Co., Ltd + * Copyright 2019-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -101,24 +101,6 @@ bool CheckAbstractScalar(const AnfNodePtr &node) { return false; } -bool CheckIfRaise(const AnfNodePtr &node) { - if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) { - auto cnode = node->cast(); - auto inputs = cnode->inputs(); - auto first = inputs[1]; - auto script_node = first->cast(); - if (script_node->value()->isa()) { - auto script = GetValueNode(script_node)->value(); - std::string raise_script = "raise_func"; - auto idx = script.find(raise_script); - if (idx != string::npos) { - return true; - } - } - } - return false; -} - void ValidateAbstract(const AnfNodePtr &node) { if (node == nullptr) { MS_LOG(DEBUG) << "Node to validate is invalid"; @@ -141,11 +123,6 @@ void ValidateAbstract(const AnfNodePtr &node) { MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString(); return; } - if (CheckIfRaise(node)) { - ShapeVector shp{abstract::Shape::kShapeRankAny}; - auto abs = std::make_shared(kFloat64, std::make_shared(shp)); - node->set_abstract(abs); - } bool is_legal_abstract = abstract->isa() || abstract->isa() || abstract->isa() || abstract->isa() || abstract->isa() || abstract->isa() || diff --git a/tests/st/fallback/test_graph_fallback_none.py b/tests/st/fallback/test_graph_fallback_none.py index 9c31e3d86c8..65d3d08ed8b 100644 --- a/tests/st/fallback/test_graph_fallback_none.py +++ b/tests/st/fallback/test_graph_fallback_none.py @@ -167,7 +167,6 @@ def test_none_is_default_value_of_parameter(): assert res1 is None -@pytest.mark.skip(reason="No support print x side effect.") @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_arm_ascend_training @@ -191,7 +190,7 @@ def test_none_is_default_value_of_parameter_2(): x = [1, 2] y = [3, 4] res2 = foo(x, y) - assert res2 == (3, 4) + assert res2 == [3, 4] @pytest.mark.level0 @@ -463,7 +462,6 @@ def test_none_is_input_of_tuple_return_2(): assert out_me_graph == out_me_pynative -@pytest.mark.skip(reason="No support print side effect.") @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_arm_ascend_training @@ -489,3 +487,66 @@ def test_none_is_return_of_sub_graph_control_flow(): data = Tensor(np.ones([2, 3]), dtype=ms.float32) out = net(data) assert (out.asnumpy() == data).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_none_is_return_of_sub_graph_control_flow_raise(): + """ + Feature: Support None. + Description: Support None is the return of sub_graph in control flow. And Raise node is in sub_graph. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def inner_func(self, x): # pylint: disable=R1711 + if x == 2: + raise ValueError("The input should not be ", x) + return None + + def construct(self, x): + self.inner_func(x) + return x + + net = RaiseNet() + res = net(Tensor(1)) + print("res:", res) + assert res.asnumpy() == 1 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_none_is_return_raise(): + """ + Feature: Support None. + Description: Support None is the return of sub_graph in control flow. And Raise node is in sub_graph. + Expectation: No exception. + """ + def check_test(shp): # pylint: disable=R1711 + if shp[0] > 5: + raise ValueError('raise value error.') + return None + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.one = Tensor(1, dtype=ms.float32) + + def construct(self, x): + shp = x.shape + check_test(shp) + return x + + with pytest.raises(ValueError, match="raise value error."): + np_data = np.random.randint(6, size=(6,)) + data = Tensor(np_data, dtype=ms.float32) + dyn_tensor = Tensor(shape=[None], dtype=ms.float32) + net = Net() + net.set_inputs(dyn_tensor) + out = net(data) + assert out == data diff --git a/tests/st/graph_syntax/statements/test_graph_raise_with_variable.py b/tests/st/graph_syntax/statements/test_graph_raise_with_variable.py index 93925bcb73d..2166e4a03b0 100644 --- a/tests/st/graph_syntax/statements/test_graph_raise_with_variable.py +++ b/tests/st/graph_syntax/statements/test_graph_raise_with_variable.py @@ -228,10 +228,10 @@ def test_raise_with_variable_control_flow2(): Expectation: No exception. """ class RaiseNet(nn.Cell): - def construct(self, x, y): + def construct(self, x, y): # pylint: disable=R1711 if x == y: raise RuntimeError(f"The input should not be {x}.") - return x + return None with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor: net = RaiseNet() @@ -243,32 +243,6 @@ def test_raise_with_variable_control_flow2(): raise_info_joinedstr_tensor.value) -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_raise_with_variable_control_flow3(): - """ - Feature: graph raise by JIT Fallback. - Description: Test raise. - Expectation: No exception. - """ - class RaiseNet(nn.Cell): - def construct(self, x, y, z): - if x == y: - raise RuntimeError(f"The input should not be {x}.") - return z - - with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor: - net = RaiseNet() - x = Tensor(1) - y = Tensor(1) - z = (x, y) - res = net(x, y, z) - print("res:", res) - assert "The input should not be 1" in str( - raise_info_joinedstr_tensor.value) - - @pytest.mark.skip(reason="not support yet") @pytest.mark.level0 @pytest.mark.platform_x86_cpu