From 425a207663370af08f9215da40ea854146a1d29a Mon Sep 17 00:00:00 2001 From: gongchen Date: Mon, 27 Apr 2020 21:10:09 +0800 Subject: [PATCH] bug(SA): Add the support of nested loop. --- mindspore/ccsrc/optimizer/optimizer.h | 22 +++--- mindspore/ccsrc/pipeline/action.cc | 13 +++- .../pipeline/static_analysis/evaluator.cc | 24 ++++++ .../pipeline/static_analysis/evaluator.h | 6 ++ .../static_analysis/static_analysis.cc | 77 ++++++++++++++++--- .../static_analysis/static_analysis.h | 2 + .../python/pynative_mode/test_framstruct.py | 53 ++++++++++++- .../pynative_mode/test_multigraph_sink.py | 18 ++++- 8 files changed, 190 insertions(+), 25 deletions(-) diff --git a/mindspore/ccsrc/optimizer/optimizer.h b/mindspore/ccsrc/optimizer/optimizer.h index 1a0ddbc65f8..d5808b48188 100644 --- a/mindspore/ccsrc/optimizer/optimizer.h +++ b/mindspore/ccsrc/optimizer/optimizer.h @@ -27,14 +27,13 @@ #include #include -#ifdef DEBUG #include "debug/draw.h" #include "debug/anf_ir_dump.h" -#endif +#include "debug/trace.h" #include "optimizer/opt.h" #include "pipeline/resource.h" #include "pipeline/action.h" -#include "debug/trace.h" +#include "utils/context/ms_context.h" namespace mindspore { namespace opt { @@ -133,7 +132,7 @@ class Optimizer : public std::enable_shared_from_this { FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { // Optimizer step counter; - int counter = 1; + int counter = -1; bool changes = true; while (changes) { @@ -170,13 +169,14 @@ class Optimizer : public std::enable_shared_from_this { } }; use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); -#ifdef DEBUG - MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; - auto fg_name = name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; - func_graph->DumpFuncGraph(fg_name); - DumpIR(fg_name + ".ir", func_graph); - MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; -#endif + if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) { + MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; + auto fg_name = + "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; + func_graph->DumpFuncGraph(fg_name); + DumpIR(fg_name + ".ir", func_graph); + MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; + } } }; use_profile ? (WITH(MsProfile::GetProfile()->Lap(counter++)) run_runc) : run_runc(); diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index f15723d64d3..b22e9c9993a 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -32,6 +32,7 @@ #include "pipeline/static_analysis/static_analysis.h" #include "pipeline/static_analysis/program_specialize.h" #include "pipeline/resource.h" +#include "utils/context/ms_context.h" #include "pipeline/remove_value_node_dup.h" #include "optimizer/optimizer.h" #include "vm/transform.h" @@ -240,13 +241,23 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { } bool OptimizeAction(const ResourcePtr &res, const std::vector &passes) { + size_t counter = 0; for (auto &pass : passes) { - WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res]() { + WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() { MS_LOG(DEBUG) << "Pass " << pass.first << " start ..."; auto result = pass.second(res); if (!result) { MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first; } + if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) { + auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first; + auto func_graph = res->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + func_graph->DumpFuncGraph(fg_name); + DumpIR(fg_name + ".ir", func_graph); + MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; + } + counter++; MS_LOG(DEBUG) << "Pass " << pass.first << " end."; }; } diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc index 06d61292d7e..e1a743c36ae 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc @@ -55,6 +55,7 @@ void InferFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList & AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) { AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list); + normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list); FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list); MS_EXCEPTION_IF_NULL(parent_context_); AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list); @@ -140,7 +141,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList << ", broaded: " << mindspore::ToString(broaded_list); return broaded_list; } + return args_spec_list; +} +AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { + MS_EXCEPTION_IF_NULL(func_graph_); + if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { + return args_spec_list; + } if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) { if (parent_context_) { MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString() @@ -160,6 +168,21 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList return joined_args_spec_list; } } + if (trace_.size() != 0) { + MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); + MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back()); + // Join the last eval arguments and current arguments to check if there are loop variant. + auto joined_args_spec_list = AbstractJoin(args_spec_list, trace_.back()); + // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. + if (!(joined_args_spec_list == args_spec_list)) { + trace_.push_back(joined_args_spec_list); + func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); + } + MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); + return joined_args_spec_list; + } else { + trace_.push_back(args_spec_list); + } } return args_spec_list; } @@ -224,6 +247,7 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar return conf->GetEvaluatedValue(); }); args_spec_list = NormalizeArgs(args_spec_list); + args_spec_list = BroadenUndeterminedArgs(args_spec_list); trace::TraceGraphInferEnter(shared_from_base(), out_conf); InferEntryLogging(shared_from_base(), args_spec_list, out_conf); MS_EXCEPTION_IF_NULL(cache_); diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/static_analysis/evaluator.h index c745bca1e93..5f06364500b 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.h +++ b/mindspore/ccsrc/pipeline/static_analysis/evaluator.h @@ -47,6 +47,10 @@ class Evaluator : public Base { virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } + virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { + return args_spec_list; + } + std::string ToString() const override { return identifier_; } virtual AnfNodePtr bound_node() const { return bound_node_.lock(); } @@ -181,12 +185,14 @@ class FuncGraphEvaluator : public BaseFuncGraphEvaluator { FuncGraphPtr func_graph() { return func_graph_; } AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override; + AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) override; std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); } private: FuncGraphPtr func_graph_; std::unordered_map func_graph_cache_; + std::vector trace_; }; using FuncGraphEvaluatorPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc index c5ee7447f1d..0de4ce4e995 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc @@ -19,6 +19,7 @@ #include "pipeline/static_analysis/static_analysis.h" #include +#include #include "pipeline/static_analysis/utils.h" #include "pipeline/static_analysis/prim.h" @@ -239,7 +240,6 @@ AbstractBasePtr AnalysisEngine::InferCNode(const CNodePtr &cnode, const AnfNodeC for (std::size_t i = 1; i < inputs.size(); i++) { const AnfNodePtr &node = inputs[i]; args_conf_list.push_back(MakeConfig(node, context)); - MS_LOG(DEBUG) << "Current CNode args_conf_list[" << i << "] node: " << node->DebugString(); } std::vector infs; @@ -469,6 +469,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector AbstractBasePtr { @@ -478,28 +482,81 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectorcast(); if (fg_eval) { - auto undetermined_fgs = fg_eval->func_graph()->recursive_graphs(); + auto fg = fg_eval->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + auto undetermined_fgs = fg->recursive_graphs(); if (undetermined_fgs) { - for (auto undetermined_fg : *undetermined_fgs) { - MS_LOG(DEBUG) << "Set graph undetermined: " << undetermined_fg->ToString(); - // As the current evaluator has multiple possibles, all the func_graphs which - // are recursive with the current func_graph are undetermined in control flow. - undetermined_fg->set_flags(kFuncGraphFlagUndetermined, true); - } + auto fg_parent = fg->parent(); + MS_EXCEPTION_IF_NULL(fg_parent); + fg_parent->set_flags(kFuncGraphFlagUndetermined, true); + MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); } } auto current_inf = std::make_pair(eval, args_spec_list); + MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); + // If current evaluator is under tracing, then skip current evaluator to avoid recursively inferring. - auto it = std::find(eval_trace_.begin(), eval_trace_.end(), current_inf); - if (it == eval_trace_.end()) { + auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); + if (it == eval_trace_.rend()) { eval_trace_.push_back(current_inf); + MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get(); MS_EXCEPTION_IF_NULL(eval); auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf); MS_EXCEPTION_IF_NULL(out_spec); MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString(); out_specs.push_back(out_spec); + MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString(); eval_trace_.pop_back(); + if (eval_trace_.empty()) { + multi_poss_.clear(); + } + } else if (it != eval_trace_.rbegin()) { + // Find latest entry function to handle nested recursion. + EvaluatorPtr latest_entry = eval; + auto latest_entry_iter = eval_trace_.rbegin(); + for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) { + auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); + if (it_temp != evaluators.end()) { + latest_entry = *it_temp; + latest_entry_iter = r_it; + break; + } + latest_entry_iter = ++r_it; + } + if (latest_entry != eval) { + MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); + continue; + } + + bool has_undetermined = false; + // Check whether sub loop has untraced undetermined evaluator. + std::set> undetermined_evals; + for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { + undetermined_evals.insert(*r_it); + } + MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); + for (auto u_eval : undetermined_evals) { + MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; + if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { + MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; + has_undetermined = true; + break; + } + } + if (has_undetermined == false) { + MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; + continue; + } + + // Try to travel the latest undetermined. + if (latest_entry != eval_trace_.rbegin()->first) { + MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); + auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); + MS_EXCEPTION_IF_NULL(out_spec); + MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString(); + return out_spec; + } } } if (out_specs.size() == 0) { diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h index beffb9ee70a..549187c29ee 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h @@ -25,6 +25,7 @@ #include #include #include +#include #ifdef DEBUG #include @@ -206,6 +207,7 @@ class AnalysisEngine : public std::enable_shared_from_this { AnfNodeConfigMap anfnode_config_map_; // Use a list to trace multiple evaluators. std::list> eval_trace_; + std::map multi_poss_; AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, const ConfigPtrList &args_conf_list); diff --git a/tests/ut/python/pynative_mode/test_framstruct.py b/tests/ut/python/pynative_mode/test_framstruct.py index 7e504c405fb..355ce85bb9b 100644 --- a/tests/ut/python/pynative_mode/test_framstruct.py +++ b/tests/ut/python/pynative_mode/test_framstruct.py @@ -39,7 +39,6 @@ from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer def setup_module(module): context.set_context(mode=context.PYNATIVE_MODE) - @ms_function def while_upper_bound(upper): rval = 2 @@ -392,6 +391,58 @@ def test_grad_factorial(): res = C.grad(factorial)(3) assert res == 11 +@ms_function +def factorial2(n): + """ factorial """ + if n != 0: + return n * factorial2(n-1) + elif n == 1: + return 1 * factorial2(n-1) + else: + return 1 +def test_factorial2(): + res = factorial2(3) + assert res == 6 + +@ms_function +def foo(n): + if n <= 1: + if n == 1: + return foo(n-1) + else: + return 1 + else: + return foo(n-1) +def test_foo(): + res = foo(5) + assert res == 1 + +@ms_function +def double_nested_loop(x): + i = 0 + s = 0 + while(i < x): + j = 0 + i = i + 1 + while(j < 3): + j = j + 1 + s = s + j + return s +def test_nested_loop(): + res = double_nested_loop(3) + assert res == 18 + +@ms_function +def double_nested_loop2(x): + s = 0 + for i in range(x): + for j in range(3): + s = s + j + return s +def test_nested_loop2(): + res = double_nested_loop(1) + assert res == 6 + def _for(x): """ _for """ ret = x * x diff --git a/tests/ut/python/pynative_mode/test_multigraph_sink.py b/tests/ut/python/pynative_mode/test_multigraph_sink.py index 0c69c7c2c1a..bf3d5b500da 100644 --- a/tests/ut/python/pynative_mode/test_multigraph_sink.py +++ b/tests/ut/python/pynative_mode/test_multigraph_sink.py @@ -24,7 +24,7 @@ from mindspore.ops import operations as P def setup_module(module): - context.set_context(mode = context.PYNATIVE_MODE, save_graphs = True, device_target = "Ascend") + context.set_context(mode = context.PYNATIVE_MODE, save_graphs = False, device_target = "Ascend") context.set_context(enable_task_sink = True, device_id = 0) @@ -86,7 +86,17 @@ def while_by_while(x, y, z): x = x + 1 x = x + 1 return x - +@ms_function +def while_in_while(x, y, z): + out = c4 + while x < y: + z = c4 + c4 + while z < y: + z = z + 1 + out = out + z + x = x + 1 + out = out + x + return out def test_simple_if(): output = simple_if(c1, c2, c3) @@ -117,3 +127,7 @@ def test_while_by_while(): expect = Tensor([28], mstype.int32) assert output == expect +def test_while_in_while(): + output = while_in_while(c1, c2, c3) + expect = Tensor([1274], mstype.int32) + assert output == expect \ No newline at end of file