Support returning isolated nodes in control flow branches.

张清华 2023-03-06 11:10:27 +08:00
parent 19735e8afd
commit fd91357363
2 changed files with 213 additions and 197 deletions
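As a rough illustration of the scenario named in the commit message — a control-flow branch whose only statement is an isolated side-effect node, so the branch itself yields no ordinary value — a graph-mode program might look like the sketch below. The decorator, tensor values, and function name are assumptions for this sketch, not code from this commit.

import mindspore as ms
from mindspore import Tensor

# Hypothetical sketch: the `if` branch contains only an isolated side-effect
# node (the print), so that branch produces no ordinary return value.
@ms.jit
def branch_with_isolated_node(x: Tensor):
    if x.sum() > 0:
        print("positive branch")  # isolated side-effect node in this branch
    else:
        x = x - 1
    return x

# Usage sketch: branch_with_isolated_node(Tensor([1.0, -2.0]))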


@@ -64,6 +64,208 @@ void DecreaseStackFrameDepth() {
}
size_t StackFrameDepth() { return stack_frame_depth; }
namespace {
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
MS_EXCEPTION_IF_NULL(abstract);
if (abstract->isa<AbstractFunction>()) {
return true;
}
if (abstract->isa<AbstractSequence>()) {
auto elements = abstract->cast_ptr<AbstractSequence>()->elements();
if (std::any_of(elements.begin(), elements.end(),
[](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
return true;
}
}
return false;
}
void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
std::string thread_id, AsyncAbstractPtr async_result_branch, AsyncAbstractPtr async_result_main,
AsyncInferTaskPtr async_task, trace::TraceGraphEvalStack graph_evals,
trace::TraceCNodeEvalStack trace_c_node_evals) {
AnalysisSchedule::set_thread_id(thread_id);
// Restore the trace stacks so they can be dumped when an exception occurs.
trace::TraceEvalCNodeStackPrepare(trace_c_node_evals);
trace_c_node_evals.clear();
trace::TraceGraphEvalStackPrepare(graph_evals);
graph_evals.clear();
try {
// Wait for the signal to run.
MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " waiting.";
(void)async_task->GetResult();
MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " running.";
// Acquire the GIL so that eval can call back into Python.
EvalResultPtr result;
{
MS_LOG(DEBUG) << std::this_thread::get_id() << " begin.";
py::gil_scoped_acquire py_guard;
result = eval->Run(engine, args_conf_list, out_conf);
}
MS_LOG(DEBUG) << std::this_thread::get_id() << " end.";
MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(result->abstract());
// Check that this branch's value is joinable with the other branch's value.
AnalysisResultCacheMgr::GetInstance().CheckSwitchValueJoinable(out_conf, result->abstract());
// Broaden the result of switch(c,t,f)()
auto broaden_abstract = result->abstract()->Broaden();
// Notify both the thread waiting for the branch value and the main thread to continue.
async_result_branch->set_result(broaden_abstract);
async_result_main->set_result(broaden_abstract);
MS_LOG(DEBUG) << GetInferThread() << " async :" << eval->ToString()
<< " asyncResult address = " << async_result_branch.get()
<< " value = " << async_result_branch->TryGetResult()->ToString();
} catch (const std::exception &ex) {
MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
AnalysisSchedule::GetInstance().HandleException(ex);
}
trace::ClearTraceStack();
ClearThreadLocal();
MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " exited.";
// The thread count is decreased when the thread exits.
AnalysisSchedule::GetInstance().DecreaseThreadCount();
}
AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs,
const std::vector<AsyncAbstractPtr> &pending_async_abstract_list,
const std::vector<std::size_t> &index) {
auto sequence_abs = dyn_cast_ptr<AbstractSequence>(orig_abs);
if (sequence_abs != nullptr) {
const auto &orig_elements = sequence_abs->elements();
AbstractBasePtrList new_elements;
for (size_t i = 0; i < orig_elements.size(); ++i) {
if (orig_elements[i]->isa<AbstractFuncAtom>()) {
AbstractFuncAtomPtrList abs_func_list{orig_elements[i]->cast<AbstractFuncAtomPtr>()};
for (size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
std::vector<std::size_t> new_index(index);
new_index.push_back(i);
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], new_index);
abs_func_list.push_back(async_func);
}
new_elements.push_back(AbstractFunction::MakeAbstractFunction(abs_func_list));
} else if (orig_elements[i]->isa<AbstractSequence>()) {
std::vector<std::size_t> new_index(index);
new_index.push_back(i);
new_elements.push_back(BuildAsyncAbstractRecursively(orig_elements[i], pending_async_abstract_list, new_index));
} else {
new_elements.push_back(orig_elements[i]);
}
}
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
AbstractBasePtr new_abs;
if (orig_abs->isa<AbstractTuple>()) {
new_abs = std::make_shared<AbstractTuple>(
new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr));
} else if (orig_abs->isa<AbstractList>()) {
new_abs = std::make_shared<AbstractList>(
new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr));
} else {
MS_LOG(EXCEPTION) << "FirstResult is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
}
return new_abs;
}
MS_LOG(EXCEPTION) << "Orig abstract is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
}
void BuildPossibleSpecs(const AbstractBasePtr &first_result,
const std::vector<AsyncAbstractPtr> &branch_async_abstract_list,
AbstractBasePtrList *out_specs) {
MS_EXCEPTION_IF_NULL(out_specs);
MS_EXCEPTION_IF_NULL(first_result);
std::vector<AsyncAbstractPtr> pending_async_abstract_list;
std::size_t len = branch_async_abstract_list.size();
for (size_t i = 0; i < len; ++i) {
auto result = branch_async_abstract_list[i]->TryGetResult();
if (result) {
out_specs->push_back(result);
} else {
pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
}
}
if (first_result->isa<AbstractFunction>()) {
for (std::size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], std::vector<size_t>{0});
out_specs->push_back(async_func);
}
} else if (first_result->isa<AbstractSequence>()) {
const auto &new_first_result =
BuildAsyncAbstractRecursively(first_result, pending_async_abstract_list, std::vector<size_t>());
MS_LOG(DEBUG) << GetInferThread() << " Try to replace old first with new one, old: " << first_result->ToString()
<< ", new: " << new_first_result->ToString();
std::replace_if(
out_specs->begin(), out_specs->end(), [first_result](const auto &element) { return element == first_result; },
new_first_result);
} else {
MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result";
}
}
void CheckInterpretedObject(const AbstractBasePtr &abs) {
static const auto support_fallback = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
static const auto use_fallback = (support_fallback != "0");
if (!use_fallback) {
return;
}
auto value = abs->BuildValue();
if (value->isa<parse::InterpretedObject>()) {
MS_LOG(ERROR) << "Do not support " << value->ToString() << ". "
<< "\nIf you are using third-party modules, you can try setting: "
<< "'export MS_DEV_SUPPORT_MODULES=module1,module2,...'.";
}
}
EvalResultPtr ConvertClassToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs, const AnfNodeConfigPtr &conf) {
auto val = abs->BuildValue();
auto class_val = dyn_cast_ptr<parse::ClassType>(val);
const auto &class_name = class_val->name();
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
auto py_fn = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CONVERT_CLASS_TO_FUNCTION, py::str(class_name));
if (py::isinstance<py::none>(py_fn)) {
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << abs->ToString() << ".";
MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
MS_EXCEPTION(ValueError) << "Can not call " << class_name << " to create python object in graph mode. "
<< "Try using 'jit_class' to decorate the class?";
}
auto list_func_fg = parse::ParsePythonCode(py_fn);
auto fg = cnode->func_graph();
list_func_fg->set_manager(fg->manager());
auto &inputs = cnode->inputs();
std::vector<AnfNodePtr> new_cnode_inputs;
(void)new_cnode_inputs.emplace_back(NewValueNode(list_func_fg));
for (std::size_t i = 1; i < inputs.size(); ++i) {
(void)new_cnode_inputs.emplace_back(inputs[i]);
}
auto new_cnode = fg->NewCNodeInOrder(new_cnode_inputs);
fg->ReplaceInOrder(cnode, new_cnode);
AnalysisEnginePtr eng = conf->engine();
MS_EXCEPTION_IF_NULL(eng);
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph());
return eng->ForwardConfig(conf, fn_conf);
}
bool CheckFuncIsolatedSideEffect(const AbstractFunctionPtr &func, bool check_isolated_side_effect) {
// Check whether the function's graph contains an isolated side-effect node.
if (check_isolated_side_effect) {
auto func_graph_abs = dyn_cast_ptr<FuncGraphAbstractClosure>(func);
if (func_graph_abs != nullptr) {
return func_graph_abs->func_graph()->has_isolated_side_effect_node();
} else {
auto meta_func_graph_abs = dyn_cast_ptr<MetaFuncGraphAbstractClosure>(func);
if (meta_func_graph_abs != nullptr) {
return meta_func_graph_abs->meta_func_graph()->has_isolated_side_effect_node();
}
}
}
return false;
}
} // namespace
EvalResultPtr PrimitiveEvalCache::Get(const PrimitivePtr &prim, const AbstractBasePtrList &args) const {
MS_EXCEPTION_IF_NULL(prim);
std::lock_guard<std::mutex> guard(mutex_);
@@ -407,51 +609,6 @@ AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode,
return possible_func;
}
void CheckInterpretedObject(const AbstractBasePtr &abs) {
static const auto support_fallback = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
static const auto use_fallback = (support_fallback != "0");
if (!use_fallback) {
return;
}
auto value = abs->BuildValue();
if (value->isa<parse::InterpretedObject>()) {
MS_LOG(ERROR) << "Do not support " << value->ToString() << ". "
<< "\nIf you are using third-party modules, you can try setting: "
<< "'export MS_DEV_SUPPORT_MODULES=module1,module2,...'.";
}
}
EvalResultPtr ConvertClassToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs, const AnfNodeConfigPtr &conf) {
auto val = abs->BuildValue();
auto class_val = dyn_cast_ptr<parse::ClassType>(val);
const auto &class_name = class_val->name();
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
auto py_fn = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CONVERT_CLASS_TO_FUNCTION, py::str(class_name));
if (py::isinstance<py::none>(py_fn)) {
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << abs->ToString() << ".";
MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
MS_EXCEPTION(ValueError) << "Can not call " << class_name << " to create python object in graph mode. "
<< "Try using 'jit_class' to decorate the class?";
}
auto list_func_fg = parse::ParsePythonCode(py_fn);
auto fg = cnode->func_graph();
list_func_fg->set_manager(fg->manager());
auto &inputs = cnode->inputs();
std::vector<AnfNodePtr> new_cnode_inputs;
(void)new_cnode_inputs.emplace_back(NewValueNode(list_func_fg));
for (std::size_t i = 1; i < inputs.size(); ++i) {
(void)new_cnode_inputs.emplace_back(inputs[i]);
}
auto new_cnode = fg->NewCNodeInOrder(new_cnode_inputs);
fg->ReplaceInOrder(cnode, new_cnode);
AnalysisEnginePtr eng = conf->engine();
MS_EXCEPTION_IF_NULL(eng);
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph());
return eng->ForwardConfig(conf, fn_conf);
}
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(cnode);
@@ -524,16 +681,17 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list);
// Check whether the called functions contain isolated side-effect nodes, and sync the flag.
if (check_isolated_side_effect()) {
auto func_graph_abs = mindspore::cast<FuncGraphAbstractClosure>(func);
if (func_graph_abs != nullptr) {
contains_isolated_side_effect =
contains_isolated_side_effect || func_graph_abs->func_graph()->has_isolated_side_effect_node();
}
auto meta_func_graph_abs = mindspore::cast<MetaFuncGraphAbstractClosure>(func);
if (meta_func_graph_abs != nullptr) {
contains_isolated_side_effect =
contains_isolated_side_effect || meta_func_graph_abs->meta_func_graph()->has_isolated_side_effect_node();
}
func->Visit([this, &contains_isolated_side_effect](const AbstractFuncAtomPtr &poss) {
MS_EXCEPTION_IF_NULL(poss);
auto resolved_atom = poss;
auto async_abs_func = poss->cast_ptr<AsyncAbstractFuncAtom>();
if (async_abs_func != nullptr) {
auto resolved_func = async_abs_func->GetUnique();
resolved_atom = dyn_cast<AbstractFuncAtom>(resolved_func);
MS_EXCEPTION_IF_NULL(resolved_atom);
}
contains_isolated_side_effect |= CheckFuncIsolatedSideEffect(resolved_atom, check_isolated_side_effect());
});
if (contains_isolated_side_effect) {
cnode->set_has_isolated_side_effect_node(true);
conf->func_graph()->set_has_isolated_side_effect_node(true);
@@ -996,147 +1154,6 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
}
namespace {
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
MS_EXCEPTION_IF_NULL(abstract);
if (abstract->isa<AbstractFunction>()) {
return true;
}
if (abstract->isa<AbstractSequence>()) {
auto elements = abstract->cast_ptr<AbstractSequence>()->elements();
if (std::any_of(elements.begin(), elements.end(),
[](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
return true;
}
}
return false;
}
void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
std::string thread_id, AsyncAbstractPtr async_result_branch, AsyncAbstractPtr async_result_main,
AsyncInferTaskPtr async_task, trace::TraceGraphEvalStack graph_evals,
trace::TraceCNodeEvalStack trace_c_node_evals) {
AnalysisSchedule::set_thread_id(thread_id);
// Restore the trace stacks so they can be dumped when an exception occurs.
trace::TraceEvalCNodeStackPrepare(trace_c_node_evals);
trace_c_node_evals.clear();
trace::TraceGraphEvalStackPrepare(graph_evals);
graph_evals.clear();
try {
// Wait for the signal to run.
MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " waiting.";
(void)async_task->GetResult();
MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " running.";
// Acquire the GIL so that eval can call back into Python.
EvalResultPtr result;
{
MS_LOG(DEBUG) << std::this_thread::get_id() << " begin.";
py::gil_scoped_acquire py_guard;
result = eval->Run(engine, args_conf_list, out_conf);
}
MS_LOG(DEBUG) << std::this_thread::get_id() << " end.";
MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(result->abstract());
// Check that this branch's value is joinable with the other branch's value.
AnalysisResultCacheMgr::GetInstance().CheckSwitchValueJoinable(out_conf, result->abstract());
// Broaden the result of switch(c,t,f)()
auto broaden_abstract = result->abstract()->Broaden();
// Notify both the thread waiting for the branch value and the main thread to continue.
async_result_branch->set_result(broaden_abstract);
async_result_main->set_result(broaden_abstract);
MS_LOG(DEBUG) << GetInferThread() << " async :" << eval->ToString()
<< " asyncResult address = " << async_result_branch.get()
<< " value = " << async_result_branch->TryGetResult()->ToString();
} catch (const std::exception &ex) {
MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
AnalysisSchedule::GetInstance().HandleException(ex);
}
trace::ClearTraceStack();
ClearThreadLocal();
MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " exited.";
// The thread count is decreased when the thread exits.
AnalysisSchedule::GetInstance().DecreaseThreadCount();
}
AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs,
const std::vector<AsyncAbstractPtr> &pending_async_abstract_list,
const std::vector<std::size_t> &index) {
auto sequence_abs = dyn_cast_ptr<AbstractSequence>(orig_abs);
if (sequence_abs != nullptr) {
const auto &orig_elements = sequence_abs->elements();
AbstractBasePtrList new_elements;
for (size_t i = 0; i < orig_elements.size(); ++i) {
if (orig_elements[i]->isa<AbstractFuncAtom>()) {
AbstractFuncAtomPtrList abs_func_list{orig_elements[i]->cast<AbstractFuncAtomPtr>()};
for (size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
std::vector<std::size_t> new_index(index);
new_index.push_back(i);
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], new_index);
abs_func_list.push_back(async_func);
}
new_elements.push_back(AbstractFunction::MakeAbstractFunction(abs_func_list));
} else if (orig_elements[i]->isa<AbstractSequence>()) {
std::vector<std::size_t> new_index(index);
new_index.push_back(i);
new_elements.push_back(BuildAsyncAbstractRecursively(orig_elements[i], pending_async_abstract_list, new_index));
} else {
new_elements.push_back(orig_elements[i]);
}
}
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
AbstractBasePtr new_abs;
if (orig_abs->isa<AbstractTuple>()) {
new_abs = std::make_shared<AbstractTuple>(
new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr));
} else if (orig_abs->isa<AbstractList>()) {
new_abs = std::make_shared<AbstractList>(
new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : nullptr));
} else {
MS_LOG(EXCEPTION) << "FirstResult is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
}
return new_abs;
}
MS_LOG(EXCEPTION) << "Orig abstract is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
}
void BuildPossibleSpecs(const AbstractBasePtr &first_result,
const std::vector<AsyncAbstractPtr> &branch_async_abstract_list,
AbstractBasePtrList *out_specs) {
MS_EXCEPTION_IF_NULL(out_specs);
MS_EXCEPTION_IF_NULL(first_result);
std::vector<AsyncAbstractPtr> pending_async_abstract_list;
std::size_t len = branch_async_abstract_list.size();
for (size_t i = 0; i < len; ++i) {
auto result = branch_async_abstract_list[i]->TryGetResult();
if (result) {
out_specs->push_back(result);
} else {
pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
}
}
if (first_result->isa<AbstractFunction>()) {
for (std::size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], std::vector<size_t>{0});
out_specs->push_back(async_func);
}
} else if (first_result->isa<AbstractSequence>()) {
const auto &new_first_result =
BuildAsyncAbstractRecursively(first_result, pending_async_abstract_list, std::vector<size_t>());
MS_LOG(DEBUG) << GetInferThread() << " Try to replace old first with new one, old: " << first_result->ToString()
<< ", new: " << new_first_result->ToString();
std::replace_if(
out_specs->begin(), out_specs->end(), [first_result](const auto &element) { return element == first_result; },
new_first_result);
} else {
MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result";
}
}
} // namespace
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) {


@@ -463,7 +463,6 @@ def test_none_is_input_of_tuple_return_2():
assert out_me_graph == out_me_pynative
@pytest.mark.skip(reason="No support print side effect.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training