forked from mindspore-Ecosystem/mindspore
!766 bugfix(SA): Add the support of nested loop.
Merge pull request !766 from gongchen/nest_loop
This commit is contained in:
commit
8003a89a7b
|
@ -27,14 +27,13 @@
|
|||
#include <utility>
|
||||
#include <initializer_list>
|
||||
|
||||
#ifdef DEBUG
|
||||
#include "debug/draw.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#endif
|
||||
#include "debug/trace.h"
|
||||
#include "optimizer/opt.h"
|
||||
#include "pipeline/resource.h"
|
||||
#include "pipeline/action.h"
|
||||
#include "debug/trace.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -133,7 +132,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
|
||||
FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
|
||||
// Optimizer step counter;
|
||||
int counter = 1;
|
||||
int counter = -1;
|
||||
bool changes = true;
|
||||
|
||||
while (changes) {
|
||||
|
@ -170,13 +169,14 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
}
|
||||
};
|
||||
use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func();
|
||||
#ifdef DEBUG
|
||||
MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
|
||||
auto fg_name = name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
|
||||
func_graph->DumpFuncGraph(fg_name);
|
||||
DumpIR(fg_name + ".ir", func_graph);
|
||||
MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
|
||||
#endif
|
||||
if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) {
|
||||
MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
|
||||
auto fg_name =
|
||||
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
|
||||
func_graph->DumpFuncGraph(fg_name);
|
||||
DumpIR(fg_name + ".ir", func_graph);
|
||||
MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
|
||||
}
|
||||
}
|
||||
};
|
||||
use_profile ? (WITH(MsProfile::GetProfile()->Lap(counter++)) run_runc) : run_runc();
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "pipeline/static_analysis/static_analysis.h"
|
||||
#include "pipeline/static_analysis/program_specialize.h"
|
||||
#include "pipeline/resource.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "pipeline/remove_value_node_dup.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "vm/transform.h"
|
||||
|
@ -240,13 +241,23 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) {
|
||||
size_t counter = 0;
|
||||
for (auto &pass : passes) {
|
||||
WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res]() {
|
||||
WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() {
|
||||
MS_LOG(DEBUG) << "Pass " << pass.first << " start ...";
|
||||
auto result = pass.second(res);
|
||||
if (!result) {
|
||||
MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first;
|
||||
}
|
||||
if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) {
|
||||
auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first;
|
||||
auto func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
func_graph->DumpFuncGraph(fg_name);
|
||||
DumpIR(fg_name + ".ir", func_graph);
|
||||
MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
|
||||
}
|
||||
counter++;
|
||||
MS_LOG(DEBUG) << "Pass " << pass.first << " end.";
|
||||
};
|
||||
}
|
||||
|
|
|
@ -55,6 +55,7 @@ void InferFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &
|
|||
AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list);
|
||||
normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list);
|
||||
FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(parent_context_);
|
||||
AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list);
|
||||
|
@ -140,7 +141,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
|
|||
<< ", broaded: " << mindspore::ToString(broaded_list);
|
||||
return broaded_list;
|
||||
}
|
||||
return args_spec_list;
|
||||
}
|
||||
|
||||
AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
|
||||
return args_spec_list;
|
||||
}
|
||||
if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) {
|
||||
if (parent_context_) {
|
||||
MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString()
|
||||
|
@ -160,6 +168,21 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
|
|||
return joined_args_spec_list;
|
||||
}
|
||||
}
|
||||
if (trace_.size() != 0) {
|
||||
MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
|
||||
MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back());
|
||||
// Join the last eval arguments and current arguments to check if there are loop variant.
|
||||
auto joined_args_spec_list = AbstractJoin(args_spec_list, trace_.back());
|
||||
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
|
||||
if (!(joined_args_spec_list == args_spec_list)) {
|
||||
trace_.push_back(joined_args_spec_list);
|
||||
func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
}
|
||||
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
|
||||
return joined_args_spec_list;
|
||||
} else {
|
||||
trace_.push_back(args_spec_list);
|
||||
}
|
||||
}
|
||||
return args_spec_list;
|
||||
}
|
||||
|
@ -224,6 +247,7 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar
|
|||
return conf->GetEvaluatedValue();
|
||||
});
|
||||
args_spec_list = NormalizeArgs(args_spec_list);
|
||||
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
|
||||
trace::TraceGraphInferEnter(shared_from_base<Evaluator>(), out_conf);
|
||||
InferEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
|
||||
MS_EXCEPTION_IF_NULL(cache_);
|
||||
|
|
|
@ -47,6 +47,10 @@ class Evaluator : public Base {
|
|||
|
||||
virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; }
|
||||
|
||||
virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) {
|
||||
return args_spec_list;
|
||||
}
|
||||
|
||||
std::string ToString() const override { return identifier_; }
|
||||
|
||||
virtual AnfNodePtr bound_node() const { return bound_node_.lock(); }
|
||||
|
@ -181,12 +185,14 @@ class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
|
|||
FuncGraphPtr func_graph() { return func_graph_; }
|
||||
|
||||
AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override;
|
||||
AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) override;
|
||||
std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); }
|
||||
|
||||
private:
|
||||
FuncGraphPtr func_graph_;
|
||||
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>
|
||||
func_graph_cache_;
|
||||
std::vector<AbstractBasePtrList> trace_;
|
||||
};
|
||||
using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>;
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "pipeline/static_analysis/static_analysis.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
|
||||
#include "pipeline/static_analysis/utils.h"
|
||||
#include "pipeline/static_analysis/prim.h"
|
||||
|
@ -239,7 +240,6 @@ AbstractBasePtr AnalysisEngine::InferCNode(const CNodePtr &cnode, const AnfNodeC
|
|||
for (std::size_t i = 1; i < inputs.size(); i++) {
|
||||
const AnfNodePtr &node = inputs[i];
|
||||
args_conf_list.push_back(MakeConfig(node, context));
|
||||
MS_LOG(DEBUG) << "Current CNode args_conf_list[" << i << "] node: " << node->DebugString();
|
||||
}
|
||||
std::vector<EvaluatorPtr> infs;
|
||||
|
||||
|
@ -469,6 +469,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
|
|||
const AnfNodeConfigPtr &out_conf,
|
||||
const ConfigPtrList &args_conf_list) {
|
||||
AbstractBasePtrList out_specs;
|
||||
if (!multi_poss_.count(evaluators[0])) {
|
||||
multi_poss_[evaluators[0]] = evaluators[1];
|
||||
multi_poss_[evaluators[1]] = evaluators[0];
|
||||
}
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||
|
@ -478,28 +482,81 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
|
|||
for (auto eval : evaluators) {
|
||||
auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>();
|
||||
if (fg_eval) {
|
||||
auto undetermined_fgs = fg_eval->func_graph()->recursive_graphs();
|
||||
auto fg = fg_eval->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto undetermined_fgs = fg->recursive_graphs();
|
||||
if (undetermined_fgs) {
|
||||
for (auto undetermined_fg : *undetermined_fgs) {
|
||||
MS_LOG(DEBUG) << "Set graph undetermined: " << undetermined_fg->ToString();
|
||||
// As the current evaluator has multiple possibles, all the func_graphs which
|
||||
// are recursive with the current func_graph are undetermined in control flow.
|
||||
undetermined_fg->set_flags(kFuncGraphFlagUndetermined, true);
|
||||
}
|
||||
auto fg_parent = fg->parent();
|
||||
MS_EXCEPTION_IF_NULL(fg_parent);
|
||||
fg_parent->set_flags(kFuncGraphFlagUndetermined, true);
|
||||
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
|
||||
}
|
||||
}
|
||||
|
||||
auto current_inf = std::make_pair(eval, args_spec_list);
|
||||
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
|
||||
|
||||
// If current evaluator is under tracing, then skip current evaluator to avoid recursively inferring.
|
||||
auto it = std::find(eval_trace_.begin(), eval_trace_.end(), current_inf);
|
||||
if (it == eval_trace_.end()) {
|
||||
auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
|
||||
if (it == eval_trace_.rend()) {
|
||||
eval_trace_.push_back(current_inf);
|
||||
MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
|
||||
MS_EXCEPTION_IF_NULL(eval);
|
||||
auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf);
|
||||
MS_EXCEPTION_IF_NULL(out_spec);
|
||||
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString();
|
||||
out_specs.push_back(out_spec);
|
||||
MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString();
|
||||
eval_trace_.pop_back();
|
||||
if (eval_trace_.empty()) {
|
||||
multi_poss_.clear();
|
||||
}
|
||||
} else if (it != eval_trace_.rbegin()) {
|
||||
// Find latest entry function to handle nested recursion.
|
||||
EvaluatorPtr latest_entry = eval;
|
||||
auto latest_entry_iter = eval_trace_.rbegin();
|
||||
for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) {
|
||||
auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first);
|
||||
if (it_temp != evaluators.end()) {
|
||||
latest_entry = *it_temp;
|
||||
latest_entry_iter = r_it;
|
||||
break;
|
||||
}
|
||||
latest_entry_iter = ++r_it;
|
||||
}
|
||||
if (latest_entry != eval) {
|
||||
MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
|
||||
continue;
|
||||
}
|
||||
|
||||
bool has_undetermined = false;
|
||||
// Check whether sub loop has untraced undetermined evaluator.
|
||||
std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals;
|
||||
for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) {
|
||||
undetermined_evals.insert(*r_it);
|
||||
}
|
||||
MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
|
||||
for (auto u_eval : undetermined_evals) {
|
||||
MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
|
||||
if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
|
||||
MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
|
||||
has_undetermined = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (has_undetermined == false) {
|
||||
MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Try to travel the latest undetermined.
|
||||
if (latest_entry != eval_trace_.rbegin()->first) {
|
||||
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
|
||||
auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
|
||||
MS_EXCEPTION_IF_NULL(out_spec);
|
||||
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString();
|
||||
return out_spec;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (out_specs.size() == 0) {
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
|
||||
#ifdef DEBUG
|
||||
#include <stack>
|
||||
|
@ -206,6 +207,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
AnfNodeConfigMap anfnode_config_map_;
|
||||
// Use a list to trace multiple evaluators.
|
||||
std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
|
||||
std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
|
||||
|
||||
AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
|
||||
const ConfigPtrList &args_conf_list);
|
||||
|
|
|
@ -39,7 +39,6 @@ from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
|
|||
def setup_module(module):
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
|
||||
@ms_function
|
||||
def while_upper_bound(upper):
|
||||
rval = 2
|
||||
|
@ -392,6 +391,58 @@ def test_grad_factorial():
|
|||
res = C.grad(factorial)(3)
|
||||
assert res == 11
|
||||
|
||||
@ms_function
|
||||
def factorial2(n):
|
||||
""" factorial """
|
||||
if n != 0:
|
||||
return n * factorial2(n-1)
|
||||
elif n == 1:
|
||||
return 1 * factorial2(n-1)
|
||||
else:
|
||||
return 1
|
||||
def test_factorial2():
|
||||
res = factorial2(3)
|
||||
assert res == 6
|
||||
|
||||
@ms_function
|
||||
def foo(n):
|
||||
if n <= 1:
|
||||
if n == 1:
|
||||
return foo(n-1)
|
||||
else:
|
||||
return 1
|
||||
else:
|
||||
return foo(n-1)
|
||||
def test_foo():
|
||||
res = foo(5)
|
||||
assert res == 1
|
||||
|
||||
@ms_function
|
||||
def double_nested_loop(x):
|
||||
i = 0
|
||||
s = 0
|
||||
while(i < x):
|
||||
j = 0
|
||||
i = i + 1
|
||||
while(j < 3):
|
||||
j = j + 1
|
||||
s = s + j
|
||||
return s
|
||||
def test_nested_loop():
|
||||
res = double_nested_loop(3)
|
||||
assert res == 18
|
||||
|
||||
@ms_function
|
||||
def double_nested_loop2(x):
|
||||
s = 0
|
||||
for i in range(x):
|
||||
for j in range(3):
|
||||
s = s + j
|
||||
return s
|
||||
def test_nested_loop2():
|
||||
res = double_nested_loop(1)
|
||||
assert res == 6
|
||||
|
||||
def _for(x):
|
||||
""" _for """
|
||||
ret = x * x
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.ops import operations as P
|
|||
|
||||
|
||||
def setup_module(module):
|
||||
context.set_context(mode = context.PYNATIVE_MODE, save_graphs = True, device_target = "Ascend")
|
||||
context.set_context(mode = context.PYNATIVE_MODE, save_graphs = False, device_target = "Ascend")
|
||||
context.set_context(enable_task_sink = True, device_id = 0)
|
||||
|
||||
|
||||
|
@ -86,7 +86,17 @@ def while_by_while(x, y, z):
|
|||
x = x + 1
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
@ms_function
|
||||
def while_in_while(x, y, z):
|
||||
out = c4
|
||||
while x < y:
|
||||
z = c4 + c4
|
||||
while z < y:
|
||||
z = z + 1
|
||||
out = out + z
|
||||
x = x + 1
|
||||
out = out + x
|
||||
return out
|
||||
|
||||
def test_simple_if():
|
||||
output = simple_if(c1, c2, c3)
|
||||
|
@ -117,3 +127,7 @@ def test_while_by_while():
|
|||
expect = Tensor([28], mstype.int32)
|
||||
assert output == expect
|
||||
|
||||
def test_while_in_while():
|
||||
output = while_in_while(c1, c2, c3)
|
||||
expect = Tensor([1274], mstype.int32)
|
||||
assert output == expect
|
Loading…
Reference in New Issue