Support to return isolated node in control flow branches with raise node.
This commit is contained in:
parent
f06617b700
commit
9c92998a79
|
@ -612,6 +612,25 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
: BaseRewriter(root_graph, manager) {}
|
||||
~CleanAfterOptARewriter() override = default;
|
||||
|
||||
void UpdateAbstracts() {
|
||||
const auto &nodes = manager_->all_nodes();
|
||||
for (const auto &node : nodes) {
|
||||
const auto &abs = node->abstract();
|
||||
if (abs == nullptr) {
|
||||
continue;
|
||||
}
|
||||
// Set flag for convert AbstractNone(PyExecute) to AbstractTensor in next renormalize.
|
||||
if (IsPrimitiveCNode(node, prim::kPrimPyExecute) && abs->isa<abstract::AbstractNone>()) {
|
||||
constexpr auto data_type = "__py_execute_no_return_type__";
|
||||
if (node->has_user_data(data_type)) {
|
||||
auto type = std::make_shared<TypeAnything>();
|
||||
node->set_user_data<Type>(data_type, type);
|
||||
set_need_renormalized(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// From:
|
||||
// MakeSparseTensor(indices, values, dense_shape)
|
||||
|
@ -897,6 +916,13 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
AnfNodePtr none_execute_node =
|
||||
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimPyExecute), script_node, none_tuple_node, none_tuple_node});
|
||||
MS_LOG(DEBUG) << "none_execute_node:" << none_execute_node->DebugString();
|
||||
|
||||
// Keep AbstractNone for PyExecute, because the control flow join problem.
|
||||
auto none_type = std::make_shared<TypeNone>();
|
||||
none_execute_node->set_user_data<Type>("__py_execute_no_return_type__", none_type);
|
||||
AbstractBasePtr res = std::make_shared<abstract::AbstractNone>();
|
||||
res->set_value(kAnyValue);
|
||||
none_execute_node->set_abstract(res);
|
||||
return none_execute_node;
|
||||
}
|
||||
|
||||
|
@ -1051,6 +1077,7 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const pipeline::ResourcePtr &resou
|
|||
CleanAfterOptARewriter rewriter(root, manager);
|
||||
bool change = rewriter.Execute();
|
||||
// Renormalize for new PyExecute node.
|
||||
rewriter.UpdateAbstracts();
|
||||
if (rewriter.need_renormalized()) {
|
||||
abstract::AbstractBasePtrList new_args_spec;
|
||||
std::transform(root->parameters().begin(), root->parameters().end(), std::back_inserter(new_args_spec),
|
||||
|
|
|
@ -225,6 +225,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
|
|||
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
|
||||
if (enable_eliminate_unused_element) {
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &maybe_func = engine->GetCNodeOperatorAbstract(cnode, context, fg);
|
||||
if (maybe_func->isa<abstract::MetaFuncGraphAbstractClosure>() ||
|
||||
maybe_func->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
|
@ -232,6 +233,12 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
|
|||
SynchronizeSequenceElementsUseFlagsForFuncGraphArgs(engine, fg, cnode, abs_func_graph, context);
|
||||
}
|
||||
}
|
||||
if (engine->check_isolated_side_effect() && node_eval_result->has_isolated_side_effect()) {
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
cnode->set_has_isolated_side_effect_node(true);
|
||||
fg->set_has_isolated_side_effect_node(true);
|
||||
}
|
||||
MS_LOG(DEBUG) << "No need to jump as found result from cache for node_config";
|
||||
} else {
|
||||
node_eval_result = engine->ObtainEvalResultWithoutCache(node_conf);
|
||||
|
|
|
@ -2064,12 +2064,16 @@ EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abst
|
|||
MS_LOG(DEBUG) << "Call script: " << script << ", args: " << args_abs_list;
|
||||
|
||||
// when return value should be none
|
||||
if (current_interpret_node->has_user_data("__py_execute_no_return_type__")) {
|
||||
AbstractBasePtr res = std::make_shared<abstract::AbstractNone>();
|
||||
res->set_value(kAnyValue);
|
||||
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
|
||||
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
|
||||
return infer_result;
|
||||
constexpr auto data_type = "__py_execute_no_return_type__";
|
||||
if (current_interpret_node->has_user_data(data_type)) {
|
||||
auto type = current_interpret_node->user_data<Type>(data_type);
|
||||
if (type->isa<TypeNone>()) {
|
||||
AbstractBasePtr res = std::make_shared<abstract::AbstractNone>();
|
||||
res->set_value(kAnyValue);
|
||||
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
|
||||
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
|
||||
return infer_result;
|
||||
}
|
||||
}
|
||||
TypePtr type = kFloat64;
|
||||
if (current_interpret_node->has_user_data("__py_execute_tensor_type__")) {
|
||||
|
@ -2811,6 +2815,14 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
auto none_type = std::make_shared<TypeNone>();
|
||||
raise_error_node->set_user_data<Type>("__py_execute_no_return_type__", none_type);
|
||||
cur_graph->ReplaceInOrder(node, raise_error_node);
|
||||
|
||||
// Set isolated side effect flag for raise node.
|
||||
const auto &manager = cur_graph->manager();
|
||||
manager->Replace(node, raise_error_node);
|
||||
raise_error_node->set_has_isolated_side_effect_node(true);
|
||||
cur_graph->set_has_isolated_side_effect_node(true);
|
||||
MS_LOG(DEBUG) << "Found Side Effect Primitive CNode: " << raise_error_node->DebugString();
|
||||
|
||||
AnalysisEnginePtr eng = out_conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(raise_error_node, out_conf->context(), out_conf->func_graph());
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -199,6 +199,13 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
|
|||
node_eval_result = engine->ObtainEvalResultWithoutCache(node_conf);
|
||||
} else {
|
||||
node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
|
||||
MS_EXCEPTION_IF_NULL(node_eval_result);
|
||||
if (engine->check_isolated_side_effect() && node_eval_result->has_isolated_side_effect()) {
|
||||
const auto &cnode = current_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
cnode->set_has_isolated_side_effect_node(true);
|
||||
current_context_->func_graph()->set_has_isolated_side_effect_node(true);
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString() << ") = "
|
||||
<< (node_eval_result->abstract() ? node_eval_result->abstract()->ToString() : "Abstract null");
|
||||
|
|
|
@ -64,6 +64,208 @@ void DecreaseStackFrameDepth() {
|
|||
}
|
||||
size_t StackFrameDepth() { return stack_frame_depth; }
|
||||
|
||||
namespace {
|
||||
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (abstract->isa<AbstractFunction>()) {
|
||||
return true;
|
||||
}
|
||||
if (abstract->isa<AbstractSequence>()) {
|
||||
auto elements = abstract->cast_ptr<AbstractSequence>()->elements();
|
||||
if (std::any_of(elements.begin(), elements.end(),
|
||||
[](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
|
||||
std::string thread_id, AsyncAbstractPtr async_result_branch, AsyncAbstractPtr async_result_main,
|
||||
AsyncInferTaskPtr async_task, trace::TraceGraphEvalStack graph_evals,
|
||||
trace::TraceCNodeEvalStack trace_c_node_evals) {
|
||||
AnalysisSchedule::set_thread_id(thread_id);
|
||||
// Restore trace stack for dump stack when there is exception.
|
||||
trace::TraceEvalCNodeStackPrepare(trace_c_node_evals);
|
||||
trace_c_node_evals.clear();
|
||||
trace::TraceGraphEvalStackPrepare(graph_evals);
|
||||
graph_evals.clear();
|
||||
|
||||
try {
|
||||
// Wait for Signal to run
|
||||
MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " waiting.";
|
||||
(void)async_task->GetResult();
|
||||
MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " running.";
|
||||
|
||||
// Acquire GIL for eval to callback python.
|
||||
EvalResultPtr result;
|
||||
{
|
||||
MS_LOG(DEBUG) << std::this_thread::get_id() << " begin.";
|
||||
py::gil_scoped_acquire py_guard;
|
||||
result = eval->Run(engine, args_conf_list, out_conf);
|
||||
}
|
||||
MS_LOG(DEBUG) << std::this_thread::get_id() << " end.";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
MS_EXCEPTION_IF_NULL(result->abstract());
|
||||
|
||||
// Check the branch value to be compatible with the other branch value.
|
||||
AnalysisResultCacheMgr::GetInstance().CheckSwitchValueJoinable(out_conf, result->abstract());
|
||||
// Broaden the result of switch(c,t,f)()
|
||||
auto broaden_abstract = result->abstract()->Broaden();
|
||||
// Notify the thread of waiting for branch value and the main thread to continue.
|
||||
async_result_branch->set_result(broaden_abstract);
|
||||
async_result_main->set_result(broaden_abstract);
|
||||
MS_LOG(DEBUG) << GetInferThread() << " async :" << eval->ToString()
|
||||
<< " asyncResult address = " << async_result_branch.get()
|
||||
<< " value = " << async_result_branch->TryGetResult()->ToString();
|
||||
} catch (const std::exception &ex) {
|
||||
MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
|
||||
AnalysisSchedule::GetInstance().HandleException(ex);
|
||||
}
|
||||
trace::ClearTraceStack();
|
||||
ClearThreadLocal();
|
||||
MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " exited.";
|
||||
// Thread number will be drop when thread exits.
|
||||
AnalysisSchedule::GetInstance().DecreaseThreadCount();
|
||||
}
|
||||
|
||||
AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs,
|
||||
const std::vector<AsyncAbstractPtr> &pending_async_abstract_list,
|
||||
const std::vector<std::size_t> &index) {
|
||||
auto sequence_abs = dyn_cast_ptr<AbstractSequence>(orig_abs);
|
||||
if (sequence_abs != nullptr) {
|
||||
const auto &orig_elements = sequence_abs->elements();
|
||||
AbstractBasePtrList new_elements;
|
||||
for (size_t i = 0; i < orig_elements.size(); ++i) {
|
||||
if (orig_elements[i]->isa<AbstractFuncAtom>()) {
|
||||
AbstractFuncAtomPtrList abs_func_list{orig_elements[i]->cast<AbstractFuncAtomPtr>()};
|
||||
for (size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
|
||||
std::vector<std::size_t> new_index(index);
|
||||
new_index.push_back(i);
|
||||
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], new_index);
|
||||
abs_func_list.push_back(async_func);
|
||||
}
|
||||
new_elements.push_back(AbstractFunction::MakeAbstractFunction(abs_func_list));
|
||||
} else if (orig_elements[i]->isa<AbstractSequence>()) {
|
||||
std::vector<std::size_t> new_index(index);
|
||||
new_index.push_back(i);
|
||||
new_elements.push_back(BuildAsyncAbstractRecursively(orig_elements[i], pending_async_abstract_list, new_index));
|
||||
} else {
|
||||
new_elements.push_back(orig_elements[i]);
|
||||
}
|
||||
}
|
||||
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
|
||||
AbstractBasePtr new_abs;
|
||||
if (orig_abs->isa<AbstractTuple>()) {
|
||||
new_abs = std::make_shared<AbstractTuple>(
|
||||
new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr));
|
||||
} else if (orig_abs->isa<AbstractList>()) {
|
||||
new_abs = std::make_shared<AbstractList>(
|
||||
new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "FirstResult is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
|
||||
}
|
||||
return new_abs;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Orig abstract is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
|
||||
}
|
||||
|
||||
void BuildPossibleSpecs(const AbstractBasePtr &first_result,
|
||||
const std::vector<AsyncAbstractPtr> &branch_async_abstract_list,
|
||||
AbstractBasePtrList *out_specs) {
|
||||
MS_EXCEPTION_IF_NULL(out_specs);
|
||||
MS_EXCEPTION_IF_NULL(first_result);
|
||||
std::vector<AsyncAbstractPtr> pending_async_abstract_list;
|
||||
std::size_t len = branch_async_abstract_list.size();
|
||||
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
auto result = branch_async_abstract_list[i]->TryGetResult();
|
||||
if (result) {
|
||||
out_specs->push_back(result);
|
||||
} else {
|
||||
pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
|
||||
}
|
||||
}
|
||||
if (first_result->isa<AbstractFunction>()) {
|
||||
for (std::size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
|
||||
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], std::vector<size_t>{0});
|
||||
out_specs->push_back(async_func);
|
||||
}
|
||||
} else if (first_result->isa<AbstractSequence>()) {
|
||||
const auto &new_first_result =
|
||||
BuildAsyncAbstractRecursively(first_result, pending_async_abstract_list, std::vector<size_t>());
|
||||
MS_LOG(DEBUG) << GetInferThread() << " Try to replace old first with new one, old: " << first_result->ToString()
|
||||
<< ", new: " << new_first_result->ToString();
|
||||
std::replace_if(
|
||||
out_specs->begin(), out_specs->end(), [first_result](const auto &element) { return element == first_result; },
|
||||
new_first_result);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result";
|
||||
}
|
||||
}
|
||||
|
||||
void CheckInterpretedObject(const AbstractBasePtr &abs) {
|
||||
static const auto support_fallback = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
|
||||
static const auto use_fallback = (support_fallback != "0");
|
||||
if (!use_fallback) {
|
||||
return;
|
||||
}
|
||||
auto value = abs->BuildValue();
|
||||
if (value->isa<parse::InterpretedObject>()) {
|
||||
MS_LOG(ERROR) << "Do not support " << value->ToString() << ". "
|
||||
<< "\nIf you are using third-party modules, you can try setting: "
|
||||
<< "'export MS_DEV_SUPPORT_MODULES=module1,module2,...'.";
|
||||
}
|
||||
}
|
||||
|
||||
EvalResultPtr ConvertClassToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs, const AnfNodeConfigPtr &conf) {
|
||||
auto val = abs->BuildValue();
|
||||
auto class_val = dyn_cast_ptr<parse::ClassType>(val);
|
||||
const auto &class_name = class_val->name();
|
||||
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
||||
auto py_fn = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CONVERT_CLASS_TO_FUNCTION, py::str(class_name));
|
||||
if (py::isinstance<py::none>(py_fn)) {
|
||||
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << abs->ToString() << ".";
|
||||
MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
|
||||
MS_EXCEPTION(ValueError) << "Can not call " << class_name << " to create python object in graph mode. "
|
||||
<< "Try using 'jit_class' to decorate the class?";
|
||||
}
|
||||
auto list_func_fg = parse::ParsePythonCode(py_fn);
|
||||
auto fg = cnode->func_graph();
|
||||
list_func_fg->set_manager(fg->manager());
|
||||
|
||||
auto &inputs = cnode->inputs();
|
||||
std::vector<AnfNodePtr> new_cnode_inputs;
|
||||
(void)new_cnode_inputs.emplace_back(NewValueNode(list_func_fg));
|
||||
for (std::size_t i = 1; i < inputs.size(); ++i) {
|
||||
(void)new_cnode_inputs.emplace_back(inputs[i]);
|
||||
}
|
||||
auto new_cnode = fg->NewCNodeInOrder(new_cnode_inputs);
|
||||
fg->ReplaceInOrder(cnode, new_cnode);
|
||||
|
||||
AnalysisEnginePtr eng = conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph());
|
||||
return eng->ForwardConfig(conf, fn_conf);
|
||||
}
|
||||
|
||||
bool CheckFuncIsolatedSideEffect(const AbstractFunctionPtr &func, bool check_isolated_side_effect) {
|
||||
// Check if func graph contains isolated side-effect, and sync.
|
||||
if (check_isolated_side_effect) {
|
||||
auto func_graph_abs = dyn_cast_ptr<FuncGraphAbstractClosure>(func);
|
||||
if (func_graph_abs != nullptr) {
|
||||
return func_graph_abs->func_graph()->has_isolated_side_effect_node();
|
||||
} else {
|
||||
auto meta_func_graph_abs = dyn_cast_ptr<MetaFuncGraphAbstractClosure>(func);
|
||||
if (meta_func_graph_abs != nullptr) {
|
||||
return meta_func_graph_abs->meta_func_graph()->has_isolated_side_effect_node();
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
EvalResultPtr PrimitiveEvalCache::Get(const PrimitivePtr &prim, const AbstractBasePtrList &args) const {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
|
@ -407,51 +609,6 @@ AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode,
|
|||
return possible_func;
|
||||
}
|
||||
|
||||
void CheckInterpretedObject(const AbstractBasePtr &abs) {
|
||||
static const auto support_fallback = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
|
||||
static const auto use_fallback = (support_fallback != "0");
|
||||
if (!use_fallback) {
|
||||
return;
|
||||
}
|
||||
auto value = abs->BuildValue();
|
||||
if (value->isa<parse::InterpretedObject>()) {
|
||||
MS_LOG(ERROR) << "Do not support " << value->ToString() << ". "
|
||||
<< "\nIf you are using third-party modules, you can try setting: "
|
||||
<< "'export MS_DEV_SUPPORT_MODULES=module1,module2,...'.";
|
||||
}
|
||||
}
|
||||
|
||||
EvalResultPtr ConvertClassToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs, const AnfNodeConfigPtr &conf) {
|
||||
auto val = abs->BuildValue();
|
||||
auto class_val = dyn_cast_ptr<parse::ClassType>(val);
|
||||
const auto &class_name = class_val->name();
|
||||
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
||||
auto py_fn = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CONVERT_CLASS_TO_FUNCTION, py::str(class_name));
|
||||
if (py::isinstance<py::none>(py_fn)) {
|
||||
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << abs->ToString() << ".";
|
||||
MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
|
||||
MS_EXCEPTION(ValueError) << "Can not call " << class_name << " to create python object in graph mode. "
|
||||
<< "Try using 'jit_class' to decorate the class?";
|
||||
}
|
||||
auto list_func_fg = parse::ParsePythonCode(py_fn);
|
||||
auto fg = cnode->func_graph();
|
||||
list_func_fg->set_manager(fg->manager());
|
||||
|
||||
auto &inputs = cnode->inputs();
|
||||
std::vector<AnfNodePtr> new_cnode_inputs;
|
||||
(void)new_cnode_inputs.emplace_back(NewValueNode(list_func_fg));
|
||||
for (std::size_t i = 1; i < inputs.size(); ++i) {
|
||||
(void)new_cnode_inputs.emplace_back(inputs[i]);
|
||||
}
|
||||
auto new_cnode = fg->NewCNodeInOrder(new_cnode_inputs);
|
||||
fg->ReplaceInOrder(cnode, new_cnode);
|
||||
|
||||
AnalysisEnginePtr eng = conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph());
|
||||
return eng->ForwardConfig(conf, fn_conf);
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
@ -524,19 +681,21 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list);
|
||||
// Check if func graph contains isolated side-effect, and sync.
|
||||
if (check_isolated_side_effect()) {
|
||||
auto func_graph_abs = mindspore::cast<FuncGraphAbstractClosure>(func);
|
||||
if (func_graph_abs != nullptr) {
|
||||
contains_isolated_side_effect =
|
||||
contains_isolated_side_effect || func_graph_abs->func_graph()->has_isolated_side_effect_node();
|
||||
}
|
||||
auto meta_func_graph_abs = mindspore::cast<MetaFuncGraphAbstractClosure>(func);
|
||||
if (meta_func_graph_abs != nullptr) {
|
||||
contains_isolated_side_effect =
|
||||
contains_isolated_side_effect || meta_func_graph_abs->meta_func_graph()->has_isolated_side_effect_node();
|
||||
}
|
||||
func->Visit([this, &contains_isolated_side_effect](const AbstractFuncAtomPtr &poss) {
|
||||
MS_EXCEPTION_IF_NULL(poss);
|
||||
auto resolved_atom = poss;
|
||||
auto async_abs_func = poss->cast_ptr<AsyncAbstractFuncAtom>();
|
||||
if (async_abs_func != nullptr) {
|
||||
auto resolved_func = async_abs_func->GetUnique();
|
||||
resolved_atom = dyn_cast<AbstractFuncAtom>(resolved_func);
|
||||
MS_EXCEPTION_IF_NULL(resolved_atom);
|
||||
}
|
||||
contains_isolated_side_effect |= CheckFuncIsolatedSideEffect(resolved_atom, check_isolated_side_effect());
|
||||
});
|
||||
if (contains_isolated_side_effect) {
|
||||
cnode->set_has_isolated_side_effect_node(true);
|
||||
conf->func_graph()->set_has_isolated_side_effect_node(true);
|
||||
eval_result->set_has_isolated_side_effect(true);
|
||||
}
|
||||
}
|
||||
return eval_result;
|
||||
|
@ -996,147 +1155,6 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
|
|||
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (abstract->isa<AbstractFunction>()) {
|
||||
return true;
|
||||
}
|
||||
if (abstract->isa<AbstractSequence>()) {
|
||||
auto elements = abstract->cast_ptr<AbstractSequence>()->elements();
|
||||
if (std::any_of(elements.begin(), elements.end(),
|
||||
[](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
|
||||
std::string thread_id, AsyncAbstractPtr async_result_branch, AsyncAbstractPtr async_result_main,
|
||||
AsyncInferTaskPtr async_task, trace::TraceGraphEvalStack graph_evals,
|
||||
trace::TraceCNodeEvalStack trace_c_node_evals) {
|
||||
AnalysisSchedule::set_thread_id(thread_id);
|
||||
// Restore trace stack for dump stack when there is exception.
|
||||
trace::TraceEvalCNodeStackPrepare(trace_c_node_evals);
|
||||
trace_c_node_evals.clear();
|
||||
trace::TraceGraphEvalStackPrepare(graph_evals);
|
||||
graph_evals.clear();
|
||||
|
||||
try {
|
||||
// Wait for Signal to run
|
||||
MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " waiting.";
|
||||
(void)async_task->GetResult();
|
||||
MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " running.";
|
||||
|
||||
// Acquire GIL for eval to callback python.
|
||||
EvalResultPtr result;
|
||||
{
|
||||
MS_LOG(DEBUG) << std::this_thread::get_id() << " begin.";
|
||||
py::gil_scoped_acquire py_guard;
|
||||
result = eval->Run(engine, args_conf_list, out_conf);
|
||||
}
|
||||
MS_LOG(DEBUG) << std::this_thread::get_id() << " end.";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
MS_EXCEPTION_IF_NULL(result->abstract());
|
||||
|
||||
// Check the branch value to be compatible with the other branch value.
|
||||
AnalysisResultCacheMgr::GetInstance().CheckSwitchValueJoinable(out_conf, result->abstract());
|
||||
// Broaden the result of switch(c,t,f)()
|
||||
auto broaden_abstract = result->abstract()->Broaden();
|
||||
// Notify the thread of waiting for branch value and the main thread to continue.
|
||||
async_result_branch->set_result(broaden_abstract);
|
||||
async_result_main->set_result(broaden_abstract);
|
||||
MS_LOG(DEBUG) << GetInferThread() << " async :" << eval->ToString()
|
||||
<< " asyncResult address = " << async_result_branch.get()
|
||||
<< " value = " << async_result_branch->TryGetResult()->ToString();
|
||||
} catch (const std::exception &ex) {
|
||||
MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
|
||||
AnalysisSchedule::GetInstance().HandleException(ex);
|
||||
}
|
||||
trace::ClearTraceStack();
|
||||
ClearThreadLocal();
|
||||
MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " exited.";
|
||||
// Thread number will be drop when thread exits.
|
||||
AnalysisSchedule::GetInstance().DecreaseThreadCount();
|
||||
}
|
||||
|
||||
AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs,
|
||||
const std::vector<AsyncAbstractPtr> &pending_async_abstract_list,
|
||||
const std::vector<std::size_t> &index) {
|
||||
auto sequence_abs = dyn_cast_ptr<AbstractSequence>(orig_abs);
|
||||
if (sequence_abs != nullptr) {
|
||||
const auto &orig_elements = sequence_abs->elements();
|
||||
AbstractBasePtrList new_elements;
|
||||
for (size_t i = 0; i < orig_elements.size(); ++i) {
|
||||
if (orig_elements[i]->isa<AbstractFuncAtom>()) {
|
||||
AbstractFuncAtomPtrList abs_func_list{orig_elements[i]->cast<AbstractFuncAtomPtr>()};
|
||||
for (size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
|
||||
std::vector<std::size_t> new_index(index);
|
||||
new_index.push_back(i);
|
||||
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], new_index);
|
||||
abs_func_list.push_back(async_func);
|
||||
}
|
||||
new_elements.push_back(AbstractFunction::MakeAbstractFunction(abs_func_list));
|
||||
} else if (orig_elements[i]->isa<AbstractSequence>()) {
|
||||
std::vector<std::size_t> new_index(index);
|
||||
new_index.push_back(i);
|
||||
new_elements.push_back(BuildAsyncAbstractRecursively(orig_elements[i], pending_async_abstract_list, new_index));
|
||||
} else {
|
||||
new_elements.push_back(orig_elements[i]);
|
||||
}
|
||||
}
|
||||
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
|
||||
AbstractBasePtr new_abs;
|
||||
if (orig_abs->isa<AbstractTuple>()) {
|
||||
new_abs = std::make_shared<AbstractTuple>(
|
||||
new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr));
|
||||
} else if (orig_abs->isa<AbstractList>()) {
|
||||
new_abs = std::make_shared<AbstractList>(
|
||||
new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "FirstResult is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
|
||||
}
|
||||
return new_abs;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Orig abstract is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
|
||||
}
|
||||
|
||||
void BuildPossibleSpecs(const AbstractBasePtr &first_result,
|
||||
const std::vector<AsyncAbstractPtr> &branch_async_abstract_list,
|
||||
AbstractBasePtrList *out_specs) {
|
||||
MS_EXCEPTION_IF_NULL(out_specs);
|
||||
MS_EXCEPTION_IF_NULL(first_result);
|
||||
std::vector<AsyncAbstractPtr> pending_async_abstract_list;
|
||||
std::size_t len = branch_async_abstract_list.size();
|
||||
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
auto result = branch_async_abstract_list[i]->TryGetResult();
|
||||
if (result) {
|
||||
out_specs->push_back(result);
|
||||
} else {
|
||||
pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
|
||||
}
|
||||
}
|
||||
if (first_result->isa<AbstractFunction>()) {
|
||||
for (std::size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
|
||||
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], std::vector<size_t>{0});
|
||||
out_specs->push_back(async_func);
|
||||
}
|
||||
} else if (first_result->isa<AbstractSequence>()) {
|
||||
const auto &new_first_result =
|
||||
BuildAsyncAbstractRecursively(first_result, pending_async_abstract_list, std::vector<size_t>());
|
||||
MS_LOG(DEBUG) << GetInferThread() << " Try to replace old first with new one, old: " << first_result->ToString()
|
||||
<< ", new: " << new_first_result->ToString();
|
||||
std::replace_if(
|
||||
out_specs->begin(), out_specs->end(), [first_result](const auto &element) { return element == first_result; },
|
||||
new_first_result);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result";
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
|
||||
const AnfNodeConfigPtr &out_conf,
|
||||
const ConfigPtrList &args_conf_list) {
|
||||
|
|
|
@ -59,16 +59,23 @@ using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;
|
|||
// the class to save evaluated result: abstract value and modified attribute
|
||||
class EvalResult : public Base {
|
||||
public:
|
||||
EvalResult(const AbstractBasePtr &abs, const AttrValueMapPtr &attr) : abstract_(abs), attribute_(attr) {}
|
||||
EvalResult(const AbstractBasePtr &abs, const AttrValueMapPtr &attr)
|
||||
: abstract_(abs), attribute_(attr), has_isolated_side_effect_(false) {}
|
||||
~EvalResult() override = default;
|
||||
MS_DECLARE_PARENT(EvalResult, Base);
|
||||
const AbstractBasePtr &abstract() const { return abstract_; }
|
||||
const AttrValueMapPtr &attribute() const { return attribute_; }
|
||||
bool has_isolated_side_effect() const { return has_isolated_side_effect_; }
|
||||
void set_has_isolated_side_effect(bool has_isolated_side_effect) {
|
||||
has_isolated_side_effect_ = has_isolated_side_effect;
|
||||
}
|
||||
|
||||
private:
|
||||
AbstractBasePtr abstract_;
|
||||
// Attribute related to PrimEvaluator;
|
||||
AttrValueMapPtr attribute_;
|
||||
|
||||
bool has_isolated_side_effect_;
|
||||
};
|
||||
using EvalResultPtr = std::shared_ptr<EvalResult>;
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -101,24 +101,6 @@ bool CheckAbstractScalar(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool CheckIfRaise(const AnfNodePtr &node) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
auto first = inputs[1];
|
||||
auto script_node = first->cast<ValueNodePtr>();
|
||||
if (script_node->value()->isa<StringImm>()) {
|
||||
auto script = GetValueNode<StringImmPtr>(script_node)->value();
|
||||
std::string raise_script = "raise_func";
|
||||
auto idx = script.find(raise_script);
|
||||
if (idx != string::npos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ValidateAbstract(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
MS_LOG(DEBUG) << "Node to validate is invalid";
|
||||
|
@ -141,11 +123,6 @@ void ValidateAbstract(const AnfNodePtr &node) {
|
|||
MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString();
|
||||
return;
|
||||
}
|
||||
if (CheckIfRaise(node)) {
|
||||
ShapeVector shp{abstract::Shape::kShapeRankAny};
|
||||
auto abs = std::make_shared<abstract::AbstractTensor>(kFloat64, std::make_shared<abstract::Shape>(shp));
|
||||
node->set_abstract(abs);
|
||||
}
|
||||
bool is_legal_abstract = abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() ||
|
||||
abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
|
||||
abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
|
||||
|
|
|
@ -167,7 +167,6 @@ def test_none_is_default_value_of_parameter():
|
|||
assert res1 is None
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support print x side effect.")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -191,7 +190,7 @@ def test_none_is_default_value_of_parameter_2():
|
|||
x = [1, 2]
|
||||
y = [3, 4]
|
||||
res2 = foo(x, y)
|
||||
assert res2 == (3, 4)
|
||||
assert res2 == [3, 4]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -463,7 +462,6 @@ def test_none_is_input_of_tuple_return_2():
|
|||
assert out_me_graph == out_me_pynative
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support print side effect.")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -489,3 +487,66 @@ def test_none_is_return_of_sub_graph_control_flow():
|
|||
data = Tensor(np.ones([2, 3]), dtype=ms.float32)
|
||||
out = net(data)
|
||||
assert (out.asnumpy() == data).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_none_is_return_of_sub_graph_control_flow_raise():
|
||||
"""
|
||||
Feature: Support None.
|
||||
Description: Support None is the return of sub_graph in control flow. And Raise node is in sub_graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def inner_func(self, x): # pylint: disable=R1711
|
||||
if x == 2:
|
||||
raise ValueError("The input should not be ", x)
|
||||
return None
|
||||
|
||||
def construct(self, x):
|
||||
self.inner_func(x)
|
||||
return x
|
||||
|
||||
net = RaiseNet()
|
||||
res = net(Tensor(1))
|
||||
print("res:", res)
|
||||
assert res.asnumpy() == 1
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_none_is_return_raise():
|
||||
"""
|
||||
Feature: Support None.
|
||||
Description: Support None is the return of sub_graph in control flow. And Raise node is in sub_graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
def check_test(shp): # pylint: disable=R1711
|
||||
if shp[0] > 5:
|
||||
raise ValueError('raise value error.')
|
||||
return None
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.one = Tensor(1, dtype=ms.float32)
|
||||
|
||||
def construct(self, x):
|
||||
shp = x.shape
|
||||
check_test(shp)
|
||||
return x
|
||||
|
||||
with pytest.raises(ValueError, match="raise value error."):
|
||||
np_data = np.random.randint(6, size=(6,))
|
||||
data = Tensor(np_data, dtype=ms.float32)
|
||||
dyn_tensor = Tensor(shape=[None], dtype=ms.float32)
|
||||
net = Net()
|
||||
net.set_inputs(dyn_tensor)
|
||||
out = net(data)
|
||||
assert out == data
|
||||
|
|
|
@ -228,10 +228,10 @@ def test_raise_with_variable_control_flow2():
|
|||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
def construct(self, x, y): # pylint: disable=R1711
|
||||
if x == y:
|
||||
raise RuntimeError(f"The input should not be {x}.")
|
||||
return x
|
||||
return None
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
|
||||
net = RaiseNet()
|
||||
|
@ -243,32 +243,6 @@ def test_raise_with_variable_control_flow2():
|
|||
raise_info_joinedstr_tensor.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_with_variable_control_flow3():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self, x, y, z):
|
||||
if x == y:
|
||||
raise RuntimeError(f"The input should not be {x}.")
|
||||
return z
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
|
||||
net = RaiseNet()
|
||||
x = Tensor(1)
|
||||
y = Tensor(1)
|
||||
z = (x, y)
|
||||
res = net(x, y, z)
|
||||
print("res:", res)
|
||||
assert "The input should not be 1" in str(
|
||||
raise_info_joinedstr_tensor.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not support yet")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
|
|
Loading…
Reference in New Issue