diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 253d2d08aeb..b15637e7beb 100755 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -800,45 +800,77 @@ void AscendSession::UpdateGraphOrder(GraphId to_graph_id) { } } +size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index) { + auto output_num = AnfAlgo::GetOutputTensorNum(node); + if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { + return input_index + output_num; + } + auto &graph_inputs = graph->inputs(); + auto &valid_inputs = graph->ValidInputs(); + if (valid_inputs[input_index]) { + SetChildGraphParameter(node, graph_inputs[input_index]); + } else { + MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString(); + } + return ++input_index; +} + +size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index) { + MS_EXCEPTION_IF_NULL(value); + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString(); + } + auto &graph_inputs = graph->inputs(); + SetChildGraphParameter(value->cast(), graph_inputs[input_index]); + return ++input_index; +} + +size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index) { + auto index = input_index; + for (auto &arg : vec_args) { + if (utils::isa(arg)) { + // arg is a anf node + auto node = utils::cast(arg); + index = SetChildGraphInput(graph, node, input_index); + } else if (utils::isa(arg)) { + // arg is a tensor + auto value = utils::cast(arg); + index = SetChildGraphInput(graph, value, input_index); + } else { + MS_LOG(EXCEPTION) << "Unexpected arg type " << arg.ToString(); + } + } + return index; +} + void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { MS_LOG(INFO) << "Set input of graph " << g; auto to_graph = GetGraph(g); MS_EXCEPTION_IF_NULL(to_graph); DumpGraphInputArgs(args); UpdateGraphOrder(g); - std::vector graph_inputs = to_graph->inputs(); - auto valid_inputs = to_graph->ValidInputs(); + auto &graph_inputs = to_graph->inputs(); auto real_args = GetRealArgs(to_graph, args); size_t input_index = 0; for (size_t i = 0; i < real_args.size(); i++) { if (input_index >= graph_inputs.size()) { MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); } - if (utils::isa(real_args[i])) { + auto &real_arg = real_args[i]; + if (utils::isa(real_arg)) { // arg is a anf node - auto real_arg = utils::cast(real_args[i]); - auto real_arg_output_num = AnfAlgo::GetOutputTensorNum(real_arg); - if (!AnfAlgo::CheckPrimitiveType(real_arg, prim::kPrimTupleGetItem) && real_arg_output_num > 1) { - input_index += real_arg_output_num; - continue; - } - if (valid_inputs[input_index]) { - SetChildGraphParameter(real_arg, graph_inputs[input_index]); - } else { - MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString(); - } - input_index++; - } else if (utils::isa(args[i])) { - auto value = utils::cast(args[i]); - MS_EXCEPTION_IF_NULL(value); + auto node = utils::cast(real_arg); + input_index = SetChildGraphInput(to_graph, node, input_index); + } else if (utils::isa(real_arg)) { // arg is a tensor - if (!value->isa()) { - MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString(); - } - SetChildGraphParameter(value->cast(), graph_inputs[input_index]); - input_index++; + auto value = utils::cast(real_arg); + input_index = SetChildGraphInput(to_graph, value, input_index); + } else if (utils::isa(real_arg)) { + // arg is a VectorRef + auto vec_args = utils::cast(real_arg); + input_index = SetChildGraphInput(to_graph, vec_args, input_index); } else { - MS_LOG(EXCEPTION) << "Unexpected arg type " << args[i].ToString(); + MS_LOG(EXCEPTION) << "Unexpected arg type " << real_arg.ToString(); } } MS_LOG(INFO) << "Finish!"; diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index 0b006256a15..eec4e4ea416 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -79,6 +79,10 @@ class AscendSession : public SessionBasic { void RunOpHardwareOptimize(const std::shared_ptr &kernel_graph) const; void RunOpExecTask(const std::shared_ptr &kernel_graph) const; + size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index); + size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index); + size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index); + // merge execution order list of child graphs void MergeGraphExecOrder(); // insert assion op to sync data bettween different graphs diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index a33e8f7bd6d..8cafcc2ebc9 100755 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -88,7 +88,7 @@ class KernelGraph : public FuncGraph { void set_executable(bool executable) { executable_ = executable; } // set invalid inputs for control sink std::vector *MutableValidInputs() { return &valid_inputs_; } - std::vector ValidInputs() { return valid_inputs_; } + const std::vector &ValidInputs() const { return valid_inputs_; } private: // remove value node form graph diff --git a/mindspore/ccsrc/utils/base_ref.h b/mindspore/ccsrc/utils/base_ref.h index 6e7911d0d92..74ccff8f809 100644 --- a/mindspore/ccsrc/utils/base_ref.h +++ b/mindspore/ccsrc/utils/base_ref.h @@ -228,6 +228,8 @@ T cast(const BaseRef &handle) { class VectorRef : public BaseRef { public: + using value_type = BaseRef; + VectorRef() {} explicit VectorRef(const std::vector &elements) : elements_(elements) {} VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {} @@ -251,6 +253,13 @@ class VectorRef : public BaseRef { return elements_[dim]; } + BaseRef &operator[](const std::size_t &dim) { + if (dim >= size()) { + MS_LOG(EXCEPTION) << "Out of the size of the tuple."; + } + return elements_[dim]; + } + uint32_t type() const override { return tid(); } std::string ToString() const override; std::vector &elements() { return elements_; } diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index d754667ccee..caf4eb3ee3f 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -143,6 +143,66 @@ void MsBackend::SetSwitchGraph() { } } +// convert node from formal parameter to actual parameter, +// and actual parameter is graph user's formal parameter. +// get top while graph's parameter in recall while. +AnfNodePtr MsBackend::ConvertGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + std::unordered_map params_index; + auto result = node; + auto graph = result->func_graph(); + while (func_graph != graph) { + auto iter = graph_user_inputs_.find(graph); + if (iter == graph_user_inputs_.end()) { + break; + } + + params_index.clear(); + auto ¶ms = graph->parameters(); + for (size_t i = 0; i < params.size(); ++i) { + params_index[params[i]] = i; + } + + graph = iter->second.first; + auto &inputs = iter->second.second; + result = inputs[params_index[result]]; + } + return result; +} + +void MsBackend::SetGraphUserInputs(const FuncGraphPtr &func_graph, const FuncGraphPtr &user, + const AnfNodePtrList &inputs) { + if (graph_user_inputs_.find(func_graph) != graph_user_inputs_.end()) { + return; + } + graph_user_inputs_[func_graph] = {user, inputs}; +} + +void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args, const BaseRef &c) { + std::unordered_map params_index; + auto ¶ms = func_graph->parameters(); + for (size_t i = 0; i < params.size(); ++i) { + params_index[params[i]] = i; + } + + // recall all child graphs in this while + auto &graph_inputs = graph_inputs_[c]; + for (auto &iter : graph_inputs) { + auto &graph = iter.first; + auto &old_args = iter.second; + auto &result = graph_id_map_[graph]; + auto &inputs = result.inputs; + for (size_t i = 0; i < inputs.size(); ++i) { + auto input = ConvertGraphInput(func_graph, inputs[i]); + auto it = params_index.find(input); + if (it != params_index.end()) { + old_args[i] = args[it->second]; + } + } + sess_->SetChildGraphInput(graph, old_args); + } + graph_inputs_.erase(c); +} + // compile set input output VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { MS_LOG(DEBUG) << "set graph input:" << g; @@ -150,13 +210,20 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { sess_->SetChildGraphInput(g, args); if (is_switch_call_) { - bool curr_cond = simu_cond_map_[curr_switch_].curr_cond; - MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond; - if (0 == simu_cond_map_[curr_switch_].cond_graph_map.count(curr_cond)) { - MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g; - simu_cond_map_[curr_switch_].cond_graph_map[curr_cond] = g; - SetSwitchGraph(); + if (!curr_switch_.is_null()) { + // push this {g, args} to all user while graph_inputs for nest while, + // when current condition recall over delete this cond in graph_inputs. + for (auto &iter : graph_inputs_) { + iter.second.push_back({g, args}); + } + if (graph_inputs_.find(curr_switch_) == graph_inputs_.end()) { + graph_inputs_[curr_switch_].push_back({g, args}); + } } + bool curr_cond = simu_cond_map_[curr_switch_].curr_cond; + MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g; + simu_cond_map_[curr_switch_].cond_graph_map[curr_cond] = g; + SetSwitchGraph(); } std::vector outputs; @@ -205,42 +272,17 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) { return outputs; } -void MsBackend::SetSimuCondFlag(const BaseRef &c, int flag) { - MS_LOG(DEBUG) << "while set cond :" << c.ToString() << ", " << simu_cond_map_.size(); - - if (simu_cond_map_.find(c) == simu_cond_map_.end()) { - MS_LOG(EXCEPTION) << "error c not find"; - } - simu_cond_map_[c].flag = flag; -} - -int MsBackend::GetSimuCondFlag(const BaseRef &c) { - BaseRef cond = c; - if (cond.is_null()) { - MS_LOG(DEBUG) << "get curr_switch"; - cond = curr_switch_; - } - if (simu_cond_map_.find(cond) == simu_cond_map_.end()) { - MS_LOG(ERROR) << "error c not find"; - return -1; - } - return simu_cond_map_[cond].flag; -} - SwitchCondStatus MsBackend::SetSimuCond(const BaseRef &c, bool value) { MS_LOG(DEBUG) << "set cond :" << c.ToString() << ", " << simu_cond_map_.size(); CondGraph cond_graph; cond_graph.curr_cond = value; if (simu_cond_map_.find(c) == simu_cond_map_.end()) { - cond_graph.flag = 0; simu_cond_map_[c] = cond_graph; } if (simu_cond_map_[c].cond_graph_map.count(value)) { - if (value == true) { - return kCondAlreadyRun; - } + return kCondAlreadyRun; } simu_cond_map_[c].curr_cond = value; MS_LOG(DEBUG) << "end set cond "; diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h index b950e7adcbf..769dab473ed 100644 --- a/mindspore/ccsrc/vm/backend.h +++ b/mindspore/ccsrc/vm/backend.h @@ -16,9 +16,11 @@ #ifndef MINDSPORE_CCSRC_VM_BACKEND_H_ #define MINDSPORE_CCSRC_VM_BACKEND_H_ -#include +#include #include +#include #include +#include #include "ir/anf.h" #include "vm/segment_runner.h" @@ -45,6 +47,8 @@ class Backend { virtual bool GetCond(const BaseRef &c, bool *value); virtual void SetSwitchGraph() {} virtual void SetSwitchActive(const BaseRef &, bool) {} + virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {} + virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {} void set_curr_switch(const BaseRef &value) { curr_switch_ = value; @@ -54,8 +58,6 @@ class Backend { BaseRef curr_switch() { return curr_switch_; } virtual void Link(GraphId) {} virtual LinConvertResult GetMultiGraphRun(const FuncGraphPtr &) { return LinConvertResult(); } - virtual void SetSimuCondFlag(const BaseRef &, int) {} - virtual int GetSimuCondFlag(const BaseRef &) { return 0; } LinConvertResult multi_result() { return multi_result_; } void set_multi_result(const LinConvertResult &value) { multi_result_ = value; } @@ -75,11 +77,11 @@ class Backend { bool simu_flag_; LinConvertResult multi_result_; AnfNodePtr final_output_; + std::unordered_map> graph_user_inputs_; }; struct CondGraph { bool curr_cond; - int flag; std::unordered_map cond_graph_map; }; @@ -97,15 +99,17 @@ class MsBackend : public Backend { void SetSwitchGraph() override; void SetSwitchActive(const BaseRef &c, bool cond) override; + void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) override; + void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) override; void Link(GraphId) override; + AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &); LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override; - void SetSimuCondFlag(const BaseRef &c, int flag) override; - int GetSimuCondFlag(const BaseRef &c) override; private: session::SessionPtr sess_; std::unordered_map simu_cond_map_; std::unordered_map graph_id_map_; + std::unordered_map>, BaseRefHash> graph_inputs_; }; } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 1c3c917daef..b14bf548699 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -390,6 +390,16 @@ void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) { void CompileGraph::AddPartial(const CNodePtr &node) { auto inputs = node->inputs(); VectorRef args; + auto fn = inputs[1]; + if (!IsValueNode(fn)) { + MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph"; + } + if (backend_->is_multi_graph_sink()) { + auto func_graph = GetValueNode(fn); + args.emplace_back(func_graph); + AnfNodePtrList outs(inputs.begin() + 2, inputs.end()); + backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs); + } for (size_t i = 1; i < inputs.size(); i++) { args.emplace_back(Ref(inputs[i])); } @@ -442,12 +452,17 @@ void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) } int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) { - auto node_inputs = node->inputs(); - AnfNodePtr fn = node_inputs[0]; + auto inputs = node->inputs(); + AnfNodePtr fn = inputs[0]; + if (backend_->is_multi_graph_sink() && IsValueNode(fn)) { + auto func_graph = GetValueNode(fn); + AnfNodePtrList outs(inputs.begin() + 1, inputs.end()); + backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs); + } (void)Ref(fn); - size_t size = node_inputs.size(); + size_t size = inputs.size(); for (size_t i = size - 1; i > 0; i--) { - AddInput(node_inputs[i]); + AddInput(inputs[i]); } if (node == graph->output()) { AddTailCall(fn, size); diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc index a897c72f8f3..cf52aafdfef 100644 --- a/mindspore/ccsrc/vm/vm.cc +++ b/mindspore/ccsrc/vm/vm.cc @@ -32,7 +32,8 @@ namespace compile { // Arguments: // fn_: Callable function. // args_: Sequence of function args. -StructPartial::StructPartial(int fn, const VectorRef &args) : fn_(fn), args_(args) {} +// fg_: Graph of function. +StructPartial::StructPartial(int fn, const VectorRef &args, const FuncGraphPtr &fg) : fn_(fn), args_(args), fg_(fg) {} std::ostream &operator<<(std::ostream &os, const StructPartial &other) { os << "partial(" << other.fn_ << ", " << other.args_.ToString() << ")"; @@ -40,7 +41,7 @@ std::ostream &operator<<(std::ostream &os, const StructPartial &other) { } bool operator==(const StructPartial &lhs, const StructPartial &rhs) { - return (lhs.fn_ == rhs.fn_ && lhs.args_ == rhs.args_); + return (lhs.fn_ == rhs.fn_ && lhs.args_ == rhs.args_ && lhs.fg_ == rhs.fg_); } StructSimuSwitch::StructSimuSwitch(const BaseRef &fn, const BaseRef &value) : fn_(fn), value_(value) {} @@ -242,16 +243,6 @@ void FinalVM::InstTailCall(const VectorRef &args) { int nargs = utils::cast(args[2]); auto new_jmp = Ref(jmp); - - if (backend_->simu_flag()) { - if (backend_->GetSimuCondFlag(BaseRef()) == 2) { - MS_LOG(DEBUG) << "invoke while call tail first"; - Pop(height); - Push(1); - Popp(); - return; - } - } MoveStack(nargs, height); MS_LOG(DEBUG) << "TailCall pushp:" << pc_ << ", jmp:" << jmp; DoJmp(new_jmp); @@ -291,8 +282,30 @@ void FinalVM::InstReturn(const VectorRef &args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPartial(const VectorRef &args) { - MS_LOG(DEBUG) << "Start"; +void FinalVM::InstSimuPartial(const VectorRef &args) { + const size_t args_size = 2; + if (args.size() < args_size) { + MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is " + << args.size() << "."; + return; + } + + auto &node = args[0]; + if (!utils::isa(node)) { + MS_LOG(ERROR) << "The type of 1st input of node must be FuncGraph"; + return; + } + auto fg = utils::cast(node); + int fn_ = utils::cast(args[1]); + auto fn = utils::cast(Ref(fn_)); + MS_LOG(DEBUG) << "Partial argssize:" << args.size(); + std::vector outs(args.size() - 2); + (void)std::transform(args.begin() + 2, args.end(), outs.begin(), + [&, this](const BaseRef &a) { return Ref(utils::cast(a)); }); + Push(std::make_shared(fn, VectorRef(outs), fg)); +} + +void FinalVM::InstRealPartial(const VectorRef &args) { const size_t args_size = 1; if (args.size() < args_size) { MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is " @@ -304,10 +317,18 @@ void FinalVM::InstPartial(const VectorRef &args) { auto fn = utils::cast(Ref(fn_)); MS_LOG(DEBUG) << "Partial argssize:" << args.size(); std::vector outs(args.size() - 1); - (void)std::transform(args.begin() + 1, args.end(), outs.begin(), [&, this](const BaseRef &a) { return Ref(utils::cast(a)); }); Push(std::make_shared(fn, VectorRef(outs))); +} + +void FinalVM::InstPartial(const VectorRef &args) { + MS_LOG(DEBUG) << "Start"; + if (backend_->is_multi_graph_sink()) { + InstSimuPartial(args); + } else { + InstRealPartial(args); + } MS_LOG(DEBUG) << "End"; } @@ -328,43 +349,57 @@ void FinalVM::InstSimuSwitch(const VectorRef &args) { bool bool_value = cond; SwitchCondStatus cond_stat = backend_->SetSimuCond(c, bool_value); - int cond_flag = backend_->GetSimuCondFlag(c); - MS_LOG(DEBUG) << "Simu switch cond:" << cond << ", " << cond_flag << ", " << c.cast()->DebugString(); - if (cond_flag == 2) { - Popp(); - Popp(); - backend_->SetSimuCondFlag(c, 0); - return; - } - if (cond_stat == kCondAlreadyRun) { MS_LOG(DEBUG) << "switch alreay run bool while true jmp"; - if (cond_flag == 0) { - MS_LOG(DEBUG) << "switch second run bool while true jmp"; - backend_->SetSwitchActive(c, true); - Push(std::make_shared(Ref(vtrue), c)); - Pushsp(); - backend_->SetSimuCondFlag(c, 1); - return; - } else if (cond_flag == 1) { - MS_LOG(DEBUG) << "switch first run bool while if jmp"; - Push(std::make_shared(Ref(vfalse), c)); - (void)backend_->SetSimuCond(c, false); - backend_->SetSimuCondFlag(c, 2); - return; - } else { - MS_LOG(EXCEPTION) << "error cond not find"; - return; + BaseRef jmp = Ref(vtrue); + if (utils::isa(jmp)) { + auto new_jmp = utils::cast>(jmp); + backend_->RecallGraphInput(new_jmp->fg_, new_jmp->args_, c); } + cond_jmp_[c] = Ref(vfalse); + Push(static_cast(cond_stat)); + Popp(); + backend_->SetSwitchActive(c, bool_value); + return; } if (bool_value) { Push(std::make_shared(Ref(vtrue), c)); Pushsp(); } else { + MergeJmpArgs(Ref(vfalse), c); Push(std::make_shared(Ref(vfalse), c)); } } +void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) { + auto iter = cond_jmp_.find(c); + if (iter == cond_jmp_.end()) { + return; + } + auto old_jmp = utils::cast>(iter->second); + auto new_jmp = utils::cast>(jmp); + auto &old_args = old_jmp->args_; + auto &new_args = new_jmp->args_; + for (size_t i = 0; i < new_args.size(); ++i) { + auto &old_arg = old_args[i]; + auto &new_arg = new_args[i]; + if (utils::isa(old_arg)) { + auto old_vec_ref = utils::cast(old_arg); + if (utils::isa(new_arg)) { + auto new_vec_ref = utils::cast(new_arg); + std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref)); + } + new_arg = old_vec_ref; + } else if (utils::isa(new_arg)) { + auto new_vec_ref = utils::cast(new_arg); + new_vec_ref.push_back(old_arg); + new_arg = new_vec_ref; + } else { + new_arg = VectorRef({new_arg, old_arg}); + } + } +} + void FinalVM::InstRealSwitch(const VectorRef &args) { const size_t args_size = 3; if (args.size() != args_size) { @@ -399,6 +434,7 @@ void FinalVM::InstSwitch(const VectorRef &args) { } else { InstRealSwitch(args); } + MS_LOG(DEBUG) << "End"; } void FinalVM::InstTuple(const VectorRef &args) { diff --git a/mindspore/ccsrc/vm/vm.h b/mindspore/ccsrc/vm/vm.h index eab726a9b7f..a02eced44c5 100644 --- a/mindspore/ccsrc/vm/vm.h +++ b/mindspore/ccsrc/vm/vm.h @@ -27,6 +27,9 @@ #include #include #include +#include + +#include "ir/anf.h" #include "utils/base_ref.h" namespace mindspore { @@ -60,13 +63,14 @@ const std::vector inst_str{"call", "tail_call", "return", "partial class StructPartial : public Base { public: // Initialize StructPartial. - StructPartial(int fn, const VectorRef &args); + StructPartial(int fn, const VectorRef &args, const FuncGraphPtr &fg = nullptr); virtual ~StructPartial() = default; MS_DECLARE_PARENT(StructPartial, Base) int fn_; VectorRef args_; + FuncGraphPtr fg_; }; std::ostream &operator<<(std::ostream &os, const StructPartial &other); @@ -98,6 +102,8 @@ class FinalVM { void InstTailCall(const VectorRef &args); void InstReturn(const VectorRef &args); void InstPartial(const VectorRef &args); + void InstSimuPartial(const VectorRef &args); + void InstRealPartial(const VectorRef &args); void InstSwitch(const VectorRef &args); void InstSimuSwitch(const VectorRef &args); void InstRealSwitch(const VectorRef &args); @@ -120,6 +126,7 @@ class FinalVM { void Pushsp(); void Popsp(); void DoJmp(const BaseRef &jmp); + void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c); private: InstSet insts_; @@ -128,6 +135,7 @@ class FinalVM { std::stack retsp_; int pc_; int sp_; + std::unordered_map cond_jmp_; BackendPtr backend_; const InstFunctionMap inst_function_map = { {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }}, diff --git a/tests/st/control/test_multigraph_sink.py b/tests/st/control/test_multigraph_sink.py new file mode 100644 index 00000000000..b2732a63d49 --- /dev/null +++ b/tests/st/control/test_multigraph_sink.py @@ -0,0 +1,184 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_multigraph_sink """ +import pytest +import numpy as np +import mindspore.nn as nn +import mindspore.context as context +from mindspore.common.tensor import Tensor +from mindspore.common import dtype as mstype +from mindspore.common import ms_function +from mindspore.ops import operations as P + + +def setup_module(module): + context.set_context(mode = context.PYNATIVE_MODE, save_graphs = True, device_target = "Ascend") + context.set_context(enable_task_sink = True, device_id = 0) + + +c1 = Tensor([2], mstype.int32) +c2 = Tensor([14], mstype.int32) +c3 = Tensor([1], mstype.int32) +c4 = Tensor([0], mstype.int32) +c5 = Tensor([14], mstype.int32) + + +@ms_function +def simple_if(x, y, z): + if x < y: + x = x + 1 + else: + x = x + 2 + x = x + 3 + return x + + +@ms_function +def if_by_if(x, y, z): + if x < y: + x = x + 1 + if y > x: + x = x + 2 + x = x + 3 + return x + + +@ms_function +def if_in_if(x, y, z): + out = c4 + if x < y: + z = c4 + c4 + if z < y: + z = z + 2 + out = out + z + x = x + 3 + out = out + x + return out + + +@ms_function +def simple_while(x, y, z): + y = y + 4 + while x < y: + x = x + 1 + x = x + 3 + return x + + +@ms_function +def while_by_while(x, y, z): + while x < y: + x = x + 1 + while z < c5: + z = z + 1 + x = x + 1 + x = x + 1 + return x + + +@ms_function +def while_in_while(x, y, z): + out = c4 + while x < y: + z = c4 + c4 + while z < y: + z = z + 1 + out = out + z + x = x + 1 + out = out + x + return out + + +@ms_function +def while_by_while_in_while(x, y, z): + out = c4 + while x < c2: + y = c4 + c4 + while y < c2: + y = y + 1 + out = out + y + z = c4 + c4 + while z < c2: + z = z + 1 + out = out + z + x = x + 1 + out = out + x + return out + + +@ms_function +def while_in_while_in_while(x, y, z): + out = c4 + while x < c2: + y = c4 + c4 + while y < c2: + y = y + 1 + z = c4 + c4 + while z < c2: + z = z + 1 + out = out + z + out = out + y + x = x + 1 + out = out + x + return out + + +def test_simple_if(): + output = simple_if(c1, c2, c3) + expect = Tensor([6], mstype.int32) + assert output == expect + + +def test_if_by_if(): + output = if_by_if(c1, c2, c3) + expect = Tensor([8], mstype.int32) + assert output == expect + + +def test_if_in_if(): + output = if_in_if(c1, c2, c3) + expect = Tensor([7], mstype.int32) + assert output == expect + + +def test_simple_while(): + output = simple_while(c1, c2, c3) + expect = Tensor([21], mstype.int32) + assert output == expect + + +def test_while_by_while(): + output = while_by_while(c1, c2, c3) + expect = Tensor([28], mstype.int32) + assert output == expect + + +def test_while_in_while(): + output = while_in_while(c1, c2, c3) + expect = Tensor([1274], mstype.int32) + assert output == expect + + +def test_while_by_while_in_while(): + output = while_by_while_in_while(c1, c2, c3) + expect = Tensor([350], mstype.int32) + assert output == expect + + +def test_while_in_while_in_while(): + output = while_in_while_in_while(c1, c2, c3) + expect = Tensor([2534], mstype.int32) + assert output == expect + diff --git a/tests/ut/python/pynative_mode/test_multigraph_sink.py b/tests/ut/python/pynative_mode/test_multigraph_sink.py new file mode 100644 index 00000000000..0c69c7c2c1a --- /dev/null +++ b/tests/ut/python/pynative_mode/test_multigraph_sink.py @@ -0,0 +1,119 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_multigraph_sink """ +import pytest +import numpy as np +import mindspore.nn as nn +import mindspore.context as context +from mindspore.common.tensor import Tensor +from mindspore.common import dtype as mstype +from mindspore.common import ms_function +from mindspore.ops import operations as P + + +def setup_module(module): + context.set_context(mode = context.PYNATIVE_MODE, save_graphs = True, device_target = "Ascend") + context.set_context(enable_task_sink = True, device_id = 0) + + +c1 = Tensor([2], mstype.int32) +c2 = Tensor([14], mstype.int32) +c3 = Tensor([1], mstype.int32) +c4 = Tensor([0], mstype.int32) +c5 = Tensor([14], mstype.int32) + + +@ms_function +def simple_if(x, y, z): + if x < y: + x = x + 1 + else: + x = x + 2 + x = x + 3 + return x + + +@ms_function +def if_by_if(x, y, z): + if x < y: + x = x + 1 + if y > x: + x = x + 2 + x = x + 3 + return x + + +@ms_function +def if_in_if(x, y, z): + out = c4 + if x < y: + z = c4 + c4 + if z < y: + z = z + 2 + out = out + z + x = x + 3 + out = out + x + return out + + +@ms_function +def simple_while(x, y, z): + y = y + 4 + while x < y: + x = x + 1 + x = x + 3 + return x + + +@ms_function +def while_by_while(x, y, z): + while x < y: + x = x + 1 + while z < c5: + z = z + 1 + x = x + 1 + x = x + 1 + return x + + +def test_simple_if(): + output = simple_if(c1, c2, c3) + expect = Tensor([6], mstype.int32) + assert output == expect + + +def test_if_by_if(): + output = if_by_if(c1, c2, c3) + expect = Tensor([8], mstype.int32) + assert output == expect + + +def test_if_in_if(): + output = if_in_if(c1, c2, c3) + expect = Tensor([7], mstype.int32) + assert output == expect + + +def test_simple_while(): + output = simple_while(c1, c2, c3) + expect = Tensor([21], mstype.int32) + assert output == expect + + +def test_while_by_while(): + output = while_by_while(c1, c2, c3) + expect = Tensor([28], mstype.int32) + assert output == expect +