!3110 Delete deprecated codes of ascend control flow
Merge pull request !3110 from zhoufeng/delete-deprecated-codes
This commit is contained in: commit c20cd12216
File diff suppressed because it is too large
@@ -51,26 +51,16 @@ class AscendSession : public SessionBasic {
  py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                  const std::vector<tensor::TensorPtr> &input_tensors) override;

  // set parameters of final graph
  GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &args) override;
  // set output of final graph
  void SetFinalGraphOutput(const BaseRef &output) override;
  // insert switch and set the relative active ops
  void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override;
  // set args of child graph. The args may come from an output of other child graphs or from the final graph's parameter
  void SetChildGraphInput(GraphId g, const VectorRef &args) override;
  // get graph id in child graphs by ME front anf node pointer
  GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override;
  // get graph id of final graph
  GraphId GetFinalRunGraph() const override { return final_graph_id_; }
  // insert active to graph
  void SetActive(GraphId, GraphId) override;
  // compile child graph when session has multiple child graphs
  void CompileChildGraph(const KernelGraphPtr &child_graph);
  void RecurseGetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
  void GetSummaryNodes(KernelGraph *graph);

 private:
  void RecurseSetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
  void SetSummaryNodes(KernelGraph *graph) override;
  void InitRuntimeResource();
  void SelectKernel(const KernelGraph &kernel_graph) const;
  void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;

@@ -92,63 +82,21 @@ class AscendSession : public SessionBasic {
  void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;
  void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;

  size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index);
  size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index);
  size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index);

  void SetFinalGraphOutput(const AnfNodePtr &node);
  void SetFinalGraphOutput(const ValuePtr &value);
  void SetFinalGraphOutput(const VectorRef &vec_output);

  void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims,
                  const NotNull<std::set<KernelGraphPtr> *> memo);
  // split graphs recursively from the root graph
  void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
  void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
  void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
  static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
  static void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
  void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
  std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
                                                const std::vector<CNodePtr> &list);
  void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
  void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
  AnfNodePtr BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph, const std::vector<CNodePtr> &child_graph_list);

  // merge execution order list of child graphs
  void MergeGraphExecOrder();
  // insert assign op to sync data between different graphs
  void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
  // insert multiple assigns to graph
  void InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
  // insert active op to graph
  void InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream);
  // get execute index of graph
  size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph);
  // handle condition graph from vm
  void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id);
  // insert depend to graph, used to attach control nodes to graph
  void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node);
  // insert depend to graph, used to attach control nodes to graph
  void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
  // set child graph parameter if front arg is an anf node
  void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
  // set child graph parameter if front arg is a tensor
  void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx);
  // update the execution order of all child graphs
  void UpdateGraphOrder(GraphId to_graph);
  // handle switch when merge
  void MergeSwitchCompile();
  // get graph order vector by graph id
  std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id);
  const std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id) const;
  // get graph order type vector by graph id
  std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id);
  // copy output of if and else
  void CopyOutputOfIf(GraphId false_graph_id);
  const std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id) const;
  // check if graph cache exists
  bool GraphCacheExist(const GraphInfo &graph_info) const;
  // insert all assigns to child graph
  void InsertAllAssigns();
  // create fake output of final graph
  AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output);
  // sync initial tensors' data to device
  void SyncInitialTenosrToDevice();
  void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);

@@ -162,16 +110,10 @@ class AscendSession : public SessionBasic {
  void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
  void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;

  // member variables
  // key is final_graph_id, value is the child graph execute order of the final graph
  std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
  // key is final_graph_id, value is the graph types of the child graphs
  std::unordered_map<GraphId, std::vector<GraphType>> graph_order_types_;
  // record condition graph of while
  std::unordered_map<GraphId, GraphId> while_condition_graphs_;
  // record all conditions
  std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_;
  std::unordered_map<GraphId, AnfNodePtr> condition_output_;
  // share parameters
  std::vector<std::tuple<AnfNodePtr, GraphId, size_t>> assigns_;
  // initial tensors, these tensors will sync data to device before running the graph

@@ -108,7 +108,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
  kernel_graph->set_execution_order(execution_order);
  NamedSummaryOutputs summary_outputs;
  if (enable_summary) {
    GetSummaryNodes(kernel_graph.get());
    SetSummaryNodes(kernel_graph.get());
    summary_outputs = kernel_graph->summary_nodes();
    runtime_.IncreaseSummaryRefCount(summary_outputs);
  }

@@ -217,7 +217,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
  Reorder(&execution_order);
  graph->set_execution_order(execution_order);
  // Get summary nodes.
  GetSummaryNodes(graph.get());
  SetSummaryNodes(graph.get());
  // Remove NoOp from execution graph
  opt::RemoveNopNode(graph.get());
  // Set graph manager.

@@ -898,27 +898,6 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP
    std::queue<AnfNodePtr> seed_nodes;
    UpdateNodeEdgeList(&seed_nodes);
  }
  // update graph inputs in child graph
  auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(),
                                     [&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
                                       return n.first == old_anf_node.get();
                                     });
  if (it_real_inputs != real_inputs_.end()) {
    // erase old parameter in map
    auto old_args = it_real_inputs->second;
    real_inputs_.erase(it_real_inputs);
    // insert new parameter to map
    auto iter = std::find_if(real_inputs_.begin(), real_inputs_.end(),
                             [&new_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
                               return n.first == new_anf_node.get();
                             });
    if (iter != real_inputs_.end()) {
      MS_LOG(WARNING) << new_anf_node->DebugString() << " already exists in real inputs, will be rewritten.";
      iter->second = old_args;
    } else {
      real_inputs_.emplace_back(new_anf_node, old_args);
    }
  }
}

void KernelGraph::UpdateExecuteKernelStreamLabel() {

@@ -953,56 +932,6 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
  return result;
}

void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
  MS_EXCEPTION_IF_NULL(parameter);
  MS_EXCEPTION_IF_NULL(arg);
  MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
  MS_EXCEPTION_IF_NULL(parameter);
  MS_EXCEPTION_IF_NULL(arg);
  auto iter = std::find_if(
    real_inputs_.begin(), real_inputs_.end(),
    [&parameter](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool { return n.first == parameter; });
  if (iter != real_inputs_.end()) {
    auto &args = iter->second;
    args.push_back(arg);
  } else {
    real_inputs_.emplace_back(parameter, std::vector<AnfNodePtr>(1, arg));
  }
}

void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph) {
  unreuse_args_[arg] = from_graph;
}

void KernelGraph::UpdateCallRealInput() {
  MS_LOG(INFO) << "Update graph id: " << graph_id_;
  std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_map;
  for (auto &it : real_inputs_) {
    auto parameter = it.first;
    MS_EXCEPTION_IF_NULL(parameter);
    auto real_inputs = it.second;
    std::vector<AnfNodePtr> new_real_inputs;
    for (auto &real_input : real_inputs) {
      // if the real input is a call node, find the child graph output to act as the new real input
      auto tmp_real_input = GetCallRealOutputs(real_input);
      std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs));
      // replace the call in unreuse_args_
      auto unreuse_arg_it = unreuse_args_.find(real_input);
      if (unreuse_arg_it != unreuse_args_.end()) {
        auto old_graph = unreuse_arg_it->second;
        for (auto new_real_input : new_real_inputs) {
          // if the referenced graph output is a parameter, it is allowed to be reused
          if (!new_real_input->isa<Parameter>()) {
            unreuse_args_[new_real_input] = old_graph;
          }
        }
      }
    }
    real_inputs_map.emplace_back(parameter, new_real_inputs);
  }
  real_inputs_ = real_inputs_map;
}

void KernelGraph::PrintGraphExecuteOrder() const {
  MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order";
  for (size_t i = 0; i < execution_order_.size(); i++) {

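The deleted SetRealInput/UpdateCallRealInput helpers above keep real_inputs_ as a vector of (parameter, real-arguments) pairs and look entries up with std::find_if. A minimal, self-contained sketch of that lookup-or-append pattern follows, using hypothetical simplified types (int stands in for AnfNodePtr); it is an illustration of the deleted bookkeeping, not the MindSpore implementation.

// Sketch only: simplified stand-ins, not the real KernelGraph types.
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>

using Node = int;
using RealInputs = std::vector<std::pair<Node, std::vector<Node>>>;

// Append `arg` to the entry for `parameter`, creating the entry on first use.
void SetRealInput(RealInputs *real_inputs, Node parameter, Node arg) {
  auto iter = std::find_if(real_inputs->begin(), real_inputs->end(),
                           [parameter](const std::pair<Node, std::vector<Node>> &n) { return n.first == parameter; });
  if (iter != real_inputs->end()) {
    iter->second.push_back(arg);
  } else {
    real_inputs->emplace_back(parameter, std::vector<Node>(1, arg));
  }
}

int main() {
  RealInputs real_inputs;
  SetRealInput(&real_inputs, /*parameter=*/1, /*arg=*/10);
  SetRealInput(&real_inputs, 1, 11);
  SetRealInput(&real_inputs, 2, 20);
  for (const auto &entry : real_inputs) {
    std::cout << "param " << entry.first << " has " << entry.second.size() << " real input(s)\n";
  }
  return 0;
}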
@@ -131,16 +131,8 @@ class KernelGraph : public FuncGraph {
  void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
  // find anf node in graph
  std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
  // get real inputs
  const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs() const { return real_inputs_; }
  void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg);
  // mark unreused args
  void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph);
  const std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> &unreuse_args() const { return unreuse_args_; }
  // used to dump ir
  std::string ToString() const override;
  // update the real input if the node is a call
  void UpdateCallRealInput();

  void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
  CNodePtr get_start_label() { return start_label_; }

@@ -212,9 +204,6 @@ class KernelGraph : public FuncGraph {
  // valid inputs
  std::vector<bool> valid_inputs_;

  // new members for control sink process
  // all child graphs referred to by partial nodes
  std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_;
  // child graph execute order in root graph
  std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;

@@ -223,9 +212,6 @@ class KernelGraph : public FuncGraph {

  // parent graph
  std::shared_ptr<KernelGraph> parent_graph_;
  // record real parameters; inputs_ is the formal parameters
  std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_;
  std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> unreuse_args_;

  CNodePtr start_label_;
  CNodePtr end_goto_;

@@ -890,7 +890,7 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {

void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }

void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
  MS_LOG(DEBUG) << "Update summary Start";
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->summary_node_exist()) {

@@ -930,7 +930,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
  if (!exist_summary) {
    return;
  }
  GetSummaryNodes(graph);
  SetSummaryNodes(graph);
  auto summary_outputs = graph->summary_nodes();
  std::map<std::string, tensor::TensorPtr> params_list;
  // fetch outputs apply kernel in session & run callback functions

@@ -92,19 +92,9 @@ class SessionBasic {
  CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph);
  std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);

  // set parameters of final graph
  virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; }
  // set output of final graph
  virtual void SetFinalGraphOutput(const BaseRef &) {}
  // insert switch and set the relative active ops
  virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {}
  // set args of child graph. The args may come from an output of other child graphs or from the final graph's parameter
  virtual void SetChildGraphInput(GraphId, const VectorRef &) {}
  // get graph id in child graphs by ME front anf node pointer
  virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
  virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
  virtual void SetActive(GraphId, GraphId) {}
  virtual void GetSummaryNodes(KernelGraph *graph);
  void AssignParamKey(const KernelGraphPtr &kernel_graph);
  void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
  virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const { return true; }

@@ -120,6 +110,7 @@ class SessionBasic {
#endif

 protected:
  virtual void SetSummaryNodes(KernelGraph *graph);
  // Get graph by graph id; return nullptr if it does not exist
  KernelGraphPtr GetGraph(GraphId graph_id) const;
  virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,

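Across the session hunks above, the summary-collection hook is renamed from GetSummaryNodes to SetSummaryNodes and its override point moves behind the protected section, while public entry points such as Summary() keep calling it internally. A minimal, self-contained sketch of that pattern with hypothetical stub types (not the real MindSpore classes):

// Sketch only: hypothetical stand-ins for the session classes.
#include <iostream>
#include <map>
#include <string>

struct KernelGraph {
  std::map<std::string, int> summary_nodes;  // stand-in for the graph's summary-node table
};

class SessionBasic {
 public:
  virtual ~SessionBasic() = default;
  void Summary(KernelGraph *graph) {
    SetSummaryNodes(graph);  // call sites now use the Set* name, as in the CPU/GPU hunks above
    std::cout << "summary nodes: " << graph->summary_nodes.size() << "\n";
  }

 protected:
  // Formerly a public virtual GetSummaryNodes(); after this commit the override point
  // is a protected virtual SetSummaryNodes().
  virtual void SetSummaryNodes(KernelGraph *graph) { graph->summary_nodes["loss"] = 0; }
};

int main() {
  KernelGraph graph;
  SessionBasic session;
  session.Summary(&graph);
  return 0;
}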
@@ -1,4 +1,4 @@
file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "base.cc" "pynative_execute.cc")
file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute.cc")

if (ENABLE_GE)
    file(GLOB_RECURSE _GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute_ge.cc")

@@ -21,7 +21,6 @@
#include "utils/log_adapter.h"
#include "ir/anf.h"
#include "utils/callbacks.h"
#include "utils/graph_utils.h"
#include "utils/base_ref_extends.h"
#include "backend/session/session_factory.h"
#include "common/utils.h"

@@ -34,19 +33,6 @@ namespace compile {
bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); }
bool Backend::GetIndex(const BaseRef &c, int *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }

LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
  // multi_graph merged to one; the big graph has parameters at the beginning and only one output
  MS_LOG(DEBUG) << "graph:" << g->ToString() << " parameter size:" << g->parameters().size();
  multi_result_.inputs = g->parameters();
  final_output_ = NewValueNode("fake_output");
  multi_result_.outputs = {final_output_};
  GraphId final_g = target_sess_->GetFinalRunGraph();

  multi_result_.run = std::make_shared<RunFunc>(
    [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args, ""); });
  return multi_result_;
}

LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) {
  MS_LOG(DEBUG) << "MsConvert";
  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());

@@ -96,149 +82,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
  return result;
}

void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
  GraphId active_g = simu_cond_map_[c].cond_graph_map[cond];

  GraphId cond_g = kInvalidGraphId;
  if (utils::isa<AnfNodePtr>(c)) {
    cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
  } else {
    MS_LOG(EXCEPTION) << "cond is not an anf node:" << c.ToString();
  }
  auto before_cond = curr_switch_;
  if (curr_switch_.hash() != c.hash()) {
    // invoke while false->before true call
    if (simu_cond_map_[before_cond].cond_graph_map.count(false)) {
      active_g = simu_cond_map_[before_cond].cond_graph_map[false];
    } else {
      active_g = kInvalidGraphId;
    }
    // while x < y:
    //   z = y + 1
    //   while z < c2:
    //     out = out + 1
    //     z = z + 1
    if (active_g == cond_g) {
      active_g = kInvalidGraphId;
      simu_cond_map_[before_cond].cond_graph_map[false] = kInvalidGraphId;
    }
    MS_LOG(DEBUG) << "invoke set active:" << active_g;
  }
  MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g;
  target_sess_->SetActive(active_g, cond_g);
}

void MsBackend::SetSwitchGraph() {
  MS_LOG(DEBUG) << "SetSwitchGraph curr_switch:" << curr_switch_.ToString();

  if (is_switch_call_) {
    GraphId false_g = kInvalidGraphId;
    GraphId true_g = kInvalidGraphId;
    MS_LOG(DEBUG) << "start SetSwitchGraph";
    true_g = simu_cond_map_[curr_switch_].cond_graph_map[true];
    bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
    if (!curr_cond) {
      if (simu_cond_map_[curr_switch_].cond_graph_map.count(curr_cond)) {
        // has false branch
        false_g = simu_cond_map_[curr_switch_].cond_graph_map[false];
      }
      GraphId cond_g = kInvalidGraphId;
      if (utils::isa<AnfNodePtr>(curr_switch_)) {
        cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
      } else {
        MS_LOG(EXCEPTION) << "cond is not an anf node:" << curr_switch_.ToString();
      }
      MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
      target_sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
    }
    is_switch_call_ = false;
    MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
  }
}

// convert node from formal parameter to actual parameter,
// and actual parameter is graph user's formal parameter.
// get top while graph's parameter in recall while.
AnfNodePtr MsBackend::ConvertGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  std::unordered_map<AnfNodePtr, size_t> params_index;
  auto result = node;
  auto graph = result->func_graph();
  while (func_graph != graph) {
    auto iter = graph_user_inputs_.find(graph);
    if (iter == graph_user_inputs_.end()) {
      break;
    }

    params_index.clear();
    auto &params = graph->parameters();
    for (size_t i = 0; i < params.size(); ++i) {
      params_index[params[i]] = i;
    }

    graph = iter->second.first;
    auto &inputs = iter->second.second;
    result = inputs[params_index[result]];
  }
  return result;
}

void MsBackend::SetGraphUserInputs(const FuncGraphPtr &func_graph, const FuncGraphPtr &user,
                                   const AnfNodePtrList &inputs) {
  if (graph_user_inputs_.find(func_graph) != graph_user_inputs_.end()) {
    return;
  }
  graph_user_inputs_[func_graph] = {user, inputs};
}

void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args, const BaseRef &c) {
  std::unordered_map<AnfNodePtr, size_t> params_index;
  auto &params = func_graph->parameters();
  for (size_t i = 0; i < params.size(); ++i) {
    params_index[params[i]] = i;
  }

  // recall all child graphs in this while
  auto &graph_inputs = graph_inputs_[c];
  for (auto &iter : graph_inputs) {
    auto &graph = iter.first;
    auto &old_args = iter.second;
    auto &result = graph_id_map_[graph];
    auto &inputs = result.inputs;
    for (size_t i = 0; i < inputs.size(); ++i) {
      auto input = ConvertGraphInput(func_graph, inputs[i]);
      auto it = params_index.find(input);
      if (it != params_index.end()) {
        old_args[i] = args[it->second];
      }
    }
    target_sess_->SetChildGraphInput(graph, old_args);
  }
  graph_inputs_.erase(c);
}

// compile set input output
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
  MS_LOG(DEBUG) << "set graph input:" << g;
  // switch maybe twice
  target_sess_->SetChildGraphInput(g, args);

  if (is_switch_call_) {
    if (!curr_switch_.is_null()) {
      // push this {g, args} to all user while graph_inputs for nested while;
      // when the current condition recall is over, delete this cond in graph_inputs.
      for (auto &iter : graph_inputs_) {
        iter.second.push_back({g, args});
      }
      if (graph_inputs_.find(curr_switch_) == graph_inputs_.end()) {
        graph_inputs_[curr_switch_].push_back({g, args});
      }
    }
    bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
    MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g;
    simu_cond_map_[curr_switch_].cond_graph_map[curr_cond] = g;
    SetSwitchGraph();
  }

  std::vector<BaseRef> outputs;
  (void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
                       [](const AnfNodePtr &v) { return v; });

@@ -290,36 +136,6 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s
  return outputs;
}

SwitchCondStatus MsBackend::SetSimuCond(const BaseRef &c, bool value) {
  MS_LOG(DEBUG) << "set cond :" << c.ToString() << ", " << simu_cond_map_.size();

  CondGraph cond_graph;
  cond_graph.curr_cond = value;
  if (simu_cond_map_.find(c) == simu_cond_map_.end()) {
    simu_cond_map_[c] = cond_graph;
  }

  if (simu_cond_map_[c].cond_graph_map.count(value)) {
    return kCondAlreadyRun;
  }
  simu_cond_map_[c].curr_cond = value;
  MS_LOG(DEBUG) << "end set cond ";
  return kCondOk;
}

void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) {
  MS_LOG(DEBUG) << "Simulate run,root:" << root->ToString() << ", " << root->parameters().size();
  std::vector<BaseRef> args;
  auto parameters = root->parameters();
  (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args),
                       [](const AnfNodePtr &v) { return v; });
  MS_LOG(DEBUG) << "Simulate start";
  (void)target_sess_->SetFinalGraphInput(parameters);
  BaseRef output = rt->Eval(VectorRef(args));
  target_sess_->SetFinalGraphOutput(output);
  MS_LOG(DEBUG) << "Simulate Eval end";
}

void MsBackend::Link(GraphId graph_id) {
  if (graph_id == kInvalidGraphId) {
    graph_id = target_sess_->GetFinalRunGraph();

@@ -330,9 +146,6 @@ void MsBackend::Link(GraphId graph_id) {
Backend::Backend(const std::string &name) : name_(name) {
  MS_LOG(DEBUG) << "select backend:" << name;
  convert_fn_ = backends[name_];
  is_switch_call_ = false;
  is_multi_graph_sink_ = false;
  simu_flag_ = false;
}

MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {

@@ -43,50 +43,19 @@ class Backend {

  LinkFuncType convert_fn() { return convert_fn_; }
  std::string name() { return name_; }
  virtual void SimulateRun(FinalVMPtr, FuncGraphPtr) {}
  virtual SwitchCondStatus SetSimuCond(const BaseRef &, bool) { return kCondOk; }
  virtual bool GetCond(const BaseRef &c, bool *value);
  virtual bool GetIndex(const BaseRef &c, int *value);
  virtual void SetSwitchGraph() {}
  virtual void SetSwitchActive(const BaseRef &, bool) {}
  virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {}
  virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {}
  virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
  void set_curr_switch(const BaseRef &value) {
    curr_switch_ = value;
    is_switch_call_ = true;
  }

  BaseRef curr_switch() { return curr_switch_; }
  virtual void Link(GraphId) {}
  virtual LinConvertResult GetMultiGraphRun(const FuncGraphPtr &) { return LinConvertResult(); }
  virtual void SetDebugger() {}

  LinConvertResult multi_result() { return multi_result_; }
  void set_multi_result(const LinConvertResult &value) { multi_result_ = value; }
  AnfNodePtr final_output() const { return final_output_; }
  bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
  void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; }
  bool simu_flag() const { return simu_flag_; }
  bool is_switch_call() const { return is_switch_call_; }
  void set_simu_flag(bool simu) { simu_flag_ = simu; }

  virtual void SetDebugger() {}

 protected:
  std::string name_;
  LinkFuncType convert_fn_;
  BaseRef curr_switch_;  // curr switch node
  bool is_multi_graph_sink_;
  bool is_switch_call_;
  bool simu_flag_;
  LinConvertResult multi_result_;
  AnfNodePtr final_output_;
  std::unordered_map<FuncGraphPtr, std::pair<FuncGraphPtr, AnfNodePtrList>> graph_user_inputs_;
};

struct CondGraph {
  bool curr_cond;
  std::unordered_map<bool, GraphId> cond_graph_map;
};

class MsBackend : public Backend {

@@ -98,16 +67,7 @@ class MsBackend : public Backend {
  VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");

  VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args);
  void SimulateRun(FinalVMPtr rt, FuncGraphPtr root) override;
  SwitchCondStatus SetSimuCond(const BaseRef &c, bool value) override;

  void SetSwitchGraph() override;
  void SetSwitchActive(const BaseRef &c, bool cond) override;
  void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) override;
  void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) override;
  void Link(GraphId) override;
  AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &);
  LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override;
  GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
  VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
  void CreateOtherSession(const std::string &target);

@@ -121,9 +81,7 @@ class MsBackend : public Backend {
  session::SessionPtr other_sess_;
  std::string target_device_;
  std::string other_device_;
  std::unordered_map<BaseRef, CondGraph, BaseRefHash> simu_cond_map_;
  std::unordered_map<GraphId, LinConvertResult> graph_id_map_;
  std::unordered_map<BaseRef, std::list<std::pair<GraphId, VectorRef>>, BaseRefHash> graph_inputs_;
};
}  // namespace compile
}  // namespace mindspore

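The CondGraph struct and simu_cond_map_ removed from the backend above recorded, for each switch condition, which compiled graph belongs to the true branch and which to the false branch. A minimal sketch of that bookkeeping with simplified stand-in types (a string key instead of BaseRef); hypothetical, for illustration only:

// Sketch only: simplified stand-ins for the deleted condition-to-graph map.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

using GraphId = uint32_t;
constexpr GraphId kInvalidGraphId = UINT32_MAX;

struct CondGraph {
  bool curr_cond = false;
  std::unordered_map<bool, GraphId> cond_graph_map;  // branch value -> compiled graph id
};

int main() {
  // simu_cond_map_ was keyed by the condition node (a BaseRef); a string stands in here.
  std::unordered_map<std::string, CondGraph> simu_cond_map;

  simu_cond_map["cond0"].curr_cond = true;
  simu_cond_map["cond0"].cond_graph_map[true] = 1;   // graph compiled for the true branch
  simu_cond_map["cond0"].cond_graph_map[false] = 2;  // graph compiled for the false branch

  GraphId active = simu_cond_map["cond0"].cond_graph_map.count(false)
                       ? simu_cond_map["cond0"].cond_graph_map[false]
                       : kInvalidGraphId;
  std::cout << "graph to activate: " << active << "\n";
  return 0;
}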
@@ -515,11 +515,7 @@ int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &no
  MS_LOG(DEBUG) << "LinConvert start";
  LinConvertResult result;

  if (backend_->simu_flag()) {
    result = backend_->GetMultiGraphRun(graph);
  } else {
    result = lin_convert_(node_list, target);
  }
  result = lin_convert_(node_list, target);

  if (result.run == nullptr) {
    MS_LOG(ERROR) << "LinConvert failed";

@@ -546,27 +542,6 @@ int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &no
  return RET_SUCCESS;
}

void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
  MS_LOG(DEBUG) << "AddSinkSwitch:" << node->ToString();
  if (backend_->is_multi_graph_sink()) {
    VectorRef args;
    args.emplace_back(-1);
    MS_LOG(DEBUG) << "call::" << height_;
    AddInst(Instruction::kCall, args);

    args.clear();
    args.emplace_back(node->input(1));
    AddInst(Instruction::kSwitchReturn, args);

    args.clear();
    args.emplace_back(false);
    args.emplace_back(Ref(node->input(1)));
    args.emplace_back(Ref(node->input(2)));
    args.emplace_back(Ref(node->input(3)));
    AddInst(Instruction::kSwitch, args);
  }
}

int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);

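The LinConvert hunk above drops the multi-graph-sink simulation branch, so conversion always goes through the real converter rather than being routed by a backend flag. A minimal, self-contained sketch of the resulting single-path shape, using hypothetical names (Converter, LinConvert) rather than the actual MindSpore signatures:

// Sketch only: hypothetical simplified types illustrating the single conversion path.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct LinConvertResult {
  std::function<void()> run;
};

using Converter = std::function<LinConvertResult(const std::vector<std::string> &)>;

// After the change there is one path: no simu_flag() check, no GetMultiGraphRun().
LinConvertResult LinConvert(const Converter &lin_convert, const std::vector<std::string> &node_list) {
  return lin_convert(node_list);
}

int main() {
  Converter converter = [](const std::vector<std::string> &nodes) {
    LinConvertResult result;
    result.run = [count = nodes.size()] { std::cout << "run segment of " << count << " node(s)\n"; };
    return result;
  };
  auto result = LinConvert(converter, {"matmul", "relu"});
  result.run();
  return 0;
}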
@@ -589,7 +564,6 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
    AddPartial(node);
  } else if (IsPrimitive(fn, prim::kPrimSwitch)) {
    AddSwitch(node);
    AddSinkSwitch(node);
  } else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) {
    AddSwitchLayer(node);
  } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {

@@ -607,14 +581,6 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
  return RET_SUCCESS;
}

void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr &graph) {
  auto ret = LinConvert(graph, {});
  if (ret == RET_FAILED) {
    MS_LOG(EXCEPTION) << "MultiGraphRun failed.";
  }
  AddReturn(nullptr);
}

bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
  MS_LOG(DEBUG) << "Start split graph";
  MS_EXCEPTION_IF_NULL(graph);

@@ -659,11 +625,6 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
  return true;
}

InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr &graph) {
  InstSet inst = Run(graph);
  return inst;
}

InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);

@@ -672,12 +633,8 @@ InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
  int param_height = height_;
  MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);

  if (backend_->simu_flag()) {
    GenMultiGraphsRun(graph);
  } else {
    if (!SplitGraph(graph)) {
      return inst_;
    }
  if (!SplitGraph(graph)) {
    return inst_;
  }

  AddPadStack(param_height);

@@ -712,12 +669,6 @@ void CompileGraph::AddPartial(const CNodePtr &node) {
  if (!IsValueNode<FuncGraph>(fn)) {
    MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
  }
  if (backend_->is_multi_graph_sink()) {
    auto func_graph = GetValueNode<FuncGraphPtr>(fn);
    args.emplace_back(func_graph);
    AnfNodePtrList outs(inputs.begin() + 2, inputs.end());
    backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
  }
  for (size_t i = 1; i < inputs.size(); i++) {
    args.emplace_back(Ref(inputs[i]));
  }

@@ -739,9 +690,6 @@ void CompileGraph::AddSwitch(const CNodePtr &node) {
    MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
  }
  VectorRef args;
  if (backend_->is_multi_graph_sink()) {
    args.emplace_back(true);
  }
  args.emplace_back(Ref(inputs[1]));
  args.emplace_back(Ref(inputs[2]));
  args.emplace_back(Ref(inputs[3]));

@@ -761,11 +709,7 @@ void CompileGraph::AddSwitchLayer(const CNodePtr &node) {

void CompileGraph::AddReturn(const CNodePtr &node) {
  VectorRef args;
  if (backend_->simu_flag()) {
    args.emplace_back(Ref(backend_->final_output()));
  } else {
    args.emplace_back(Ref(node->input(1)));
  }
  args.emplace_back(Ref(node->input(1)));
  args.emplace_back(height_);
  AddInst(Instruction::kReturn, args);
}

@@ -783,11 +727,6 @@ void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim)
int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
  auto inputs = node->inputs();
  AnfNodePtr fn = inputs[0];
  if (backend_->is_multi_graph_sink() && IsValueNode<FuncGraph>(fn)) {
    auto func_graph = GetValueNode<FuncGraphPtr>(fn);
    AnfNodePtrList outs(inputs.begin() + 1, inputs.end());
    backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
  }
  (void)Ref(fn);
  size_t size = inputs.size();
  for (size_t i = size - 1; i > 0; i--) {

@@ -929,17 +868,6 @@ FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) {
  }

  FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
  if (backend_->is_multi_graph_sink()) {
    backend_->set_simu_flag(true);
    MS_LOG(DEBUG) << "Start simulate";
    backend_->SimulateRun(rt, graph);
    MS_LOG(DEBUG) << "Link graphs";
    insts_ = transform_->GenMultiGraphsSinkInst(graph);
    rt->set_insts(insts_);
    backend_->set_simu_flag(false);
    MS_LOG(DEBUG) << "End start simulate";
    backend_->Link(kInvalidGraphId);
  }
  MS_LOG(DEBUG) << "End";
  return rt;
}

@@ -54,12 +54,10 @@ class CompileGraph {
  ~CompileGraph() = default;

  InstSet Run(const FuncGraphPtr &func_graph);
  InstSet GenMultiGraphsSinkInst(const FuncGraphPtr &graph);
  bool IsCut(const AnfNodePtr &node);
  void Push(const AnfNodePtr &node);
  void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
  void Ret(int nargs);
  void GenMultiGraphsRun(const FuncGraphPtr &graph);
  int Ref(const AnfNodePtr &node);
  VectorRef SplitNodes(const FuncGraphPtr &func_graph);

@@ -84,7 +82,6 @@ class CompileGraph {
  int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");
  int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
  int AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
  void AddSinkSwitch(const CNodePtr &node);
  void AddPadStack(int param_height);
  void AddTailCall(const AnfNodePtr &fn, size_t size);
  void AddPartial(const CNodePtr &node);

@@ -17,12 +17,9 @@
 */

#include "vm/vm.h"

#include <algorithm>

#include "vm/vmimpl.h"
#include "vm/backend.h"
#include "vm/transform.h"
#include "pipeline/jit/parse/data_converter.h"
#include "utils/base_ref_extends.h"

@@ -142,33 +139,10 @@ void FinalVM::Popsp() {
  }
}

void FinalVM::PushStatus(bool is_switch_call) { ret_status_.push(is_switch_call); }

bool FinalVM::PopStatus() {
  if (ret_status_.empty()) {
    return false;
  }
  bool status = ret_status_.top();
  ret_status_.pop();
  return status;
}

void FinalVM::DoJmp(const BaseRef &jmp_orig) {
  MS_LOG(DEBUG) << "Start";

  BaseRef jmp = jmp_orig;
  if (backend_->simu_flag()) {
    bool is_switch_call = false;
    if (utils::isa<StructSimuSwitch>(jmp)) {  // need to inherit from Base
      MS_LOG(DEBUG) << "Start jump StructSwitch";
      auto simu_value = utils::cast<std::shared_ptr<StructSimuSwitch>>(jmp);
      jmp = simu_value->fn_;
      backend_->set_curr_switch(simu_value->value_);
      is_switch_call = true;
    }
    PushStatus(is_switch_call);
  }

  if (utils::isa<StructPartial>(jmp)) {  // need to inherit from Base
    MS_LOG(DEBUG) << "Start jump StructPartial";
    auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);

@@ -270,13 +244,6 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) {
    MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
    return;
  }

  auto rv = Ref(-1);
  if (utils::isa<AnfNodePtr>(rv) || utils::isa<VectorRef>(rv)) {
    auto &c = args[0];
    cond_out_[c] = rv;
  }

  Pop(1);
  Popsp();
}

@@ -294,51 +261,12 @@ void FinalVM::InstReturn(const VectorRef &args) {
  int height = utils::cast<int>(args[1]);

  auto rv = Ref(rpos);
  if (backend_->simu_flag()) {
    auto c = backend_->curr_switch();
    auto status = PopStatus();
    if (status) {
      auto iter = cond_out_.find(c);
      if (iter != cond_out_.end()) {
        rv = MergeArgs(rv, iter->second);
        cond_out_.erase(iter);
      }
    }

    if (backend_->is_switch_call()) {
      backend_->SetSwitchGraph();
    }
  }

  Pop(height);
  Push(rv);
  Popp();
  MS_LOG(DEBUG) << "End";
}

void FinalVM::InstSimuPartial(const VectorRef &args) {
  const size_t args_size = 2;
  if (args.size() < args_size) {
    MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
                  << args.size() << ".";
    return;
  }

  auto &node = args[0];
  if (!utils::isa<FuncGraphPtr>(node)) {
    MS_LOG(ERROR) << "The type of 1st input of node must be FuncGraph";
    return;
  }
  auto fg = utils::cast<FuncGraphPtr>(node);
  int fn_ = utils::cast<int>(args[1]);
  auto fn = utils::cast<int>(Ref(fn_));
  MS_LOG(DEBUG) << "Partial argssize:" << args.size();
  std::vector<BaseRef> outs(args.size() - 2);
  (void)std::transform(args.begin() + 2, args.end(), outs.begin(),
                       [&, this](const BaseRef &a) { return Ref(utils::cast<int>(a)); });
  Push(std::make_shared<StructPartial>(fn, VectorRef(outs), fg));
}

void FinalVM::InstRealPartial(const VectorRef &args) {
  const size_t args_size = 1;
  if (args.size() < args_size) {

@@ -358,91 +286,10 @@ void FinalVM::InstRealPartial(const VectorRef &args) {

void FinalVM::InstPartial(const VectorRef &args) {
  MS_LOG(DEBUG) << "Start";
  if (backend_->is_multi_graph_sink()) {
    InstSimuPartial(args);
  } else {
    InstRealPartial(args);
  }
  InstRealPartial(args);
  MS_LOG(DEBUG) << "End";
}

void FinalVM::InstSimuSwitch(const VectorRef &args) {
  const size_t args_size = 4;
  if (args.size() != args_size) {
    MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
                  << ".";
    return;
  }
  bool cond = utils::cast<bool>(args[0]);
  int cond_node = utils::cast<int>(args[1]);
  int vtrue = utils::cast<int>(args[2]);
  int vfalse = utils::cast<int>(args[3]);

  MS_LOG(DEBUG) << "Simu switch cond:" << cond;
  BaseRef c = Ref(cond_node);
  bool bool_value = cond;
  SwitchCondStatus cond_stat = backend_->SetSimuCond(c, bool_value);

  if (cond_stat == kCondAlreadyRun) {
    MS_LOG(DEBUG) << "switch already run bool while true jmp";
    BaseRef jmp = Ref(vtrue);
    if (utils::isa<StructPartial>(jmp)) {
      auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
      backend_->RecallGraphInput(new_jmp->fg_, new_jmp->args_, c);
    }
    cond_jmp_[c] = Ref(vfalse);
    Push(static_cast<int>(cond_stat));
    Popp();
    backend_->SetSwitchActive(c, bool_value);
    return;
  }
  if (bool_value) {
    Push(std::make_shared<StructSimuSwitch>(Ref(vtrue), c));
    Pushsp();
  } else {
    MergeJmpArgs(Ref(vfalse), c);
    Push(std::make_shared<StructSimuSwitch>(Ref(vfalse), c));
  }
}

void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) {
  auto iter = cond_jmp_.find(c);
  if (iter == cond_jmp_.end()) {
    return;
  }
  auto old_jmp = utils::cast<std::shared_ptr<StructPartial>>(iter->second);
  auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
  auto &old_args = old_jmp->args_;
  auto &new_args = new_jmp->args_;
  for (size_t i = 0; i < new_args.size(); ++i) {
    auto &old_arg = old_args[i];
    auto &new_arg = new_args[i];
    new_arg = MergeArgs(old_arg, new_arg);
  }
}

BaseRef FinalVM::MergeArgs(const BaseRef &first, const BaseRef &second) {
  MS_LOG(DEBUG) << __FUNCTION__ << ": " << first.ToString() << ", " << second.ToString();
  if (utils::isa<VectorRef>(first)) {
    auto old_vec_ref = utils::cast<VectorRef>(first);
    if (utils::isa<VectorRef>(second)) {
      auto new_vec_ref = utils::cast<VectorRef>(second);
      std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref));
    } else {
      old_vec_ref.push_back(second);
    }
    return old_vec_ref;
  }

  if (utils::isa<VectorRef>(second)) {
    auto new_vec_ref = utils::cast<VectorRef>(second);
    new_vec_ref.push_back(first);
    return new_vec_ref;
  }

  return VectorRef({first, second});
}

void FinalVM::InstRealSwitch(const VectorRef &args) {
  const size_t args_size = 3;
  if (args.size() != args_size) {

@@ -472,11 +319,7 @@ void FinalVM::InstRealSwitch(const VectorRef &args) {

void FinalVM::InstSwitch(const VectorRef &args) {
  MS_LOG(DEBUG) << "Start";
  if (backend_->is_multi_graph_sink()) {
    InstSimuSwitch(args);
  } else {
    InstRealSwitch(args);
  }
  InstRealSwitch(args);
  MS_LOG(DEBUG) << "End";
}

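InstSwitch and InstPartial above lose their is_multi_graph_sink() dispatch and now call the Real* handler directly. A minimal, self-contained sketch of that collapsed dispatch, with heavily simplified stand-in types (VectorRef as a plain vector); hypothetical, not the real FinalVM:

// Sketch only: simplified illustration of the direct dispatch after the Simu* path removal.
#include <iostream>
#include <vector>

using VectorRef = std::vector<int>;

class FinalVM {
 public:
  void InstSwitch(const VectorRef &args) {
    // Before: if (backend_->is_multi_graph_sink()) InstSimuSwitch(args); else InstRealSwitch(args);
    InstRealSwitch(args);
  }

 private:
  void InstRealSwitch(const VectorRef &args) {
    std::cout << "real switch with " << args.size() << " arg(s)\n";
  }
};

int main() {
  FinalVM vm;
  vm.InstSwitch({0, 1, 2});
  return 0;
}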
@@ -580,14 +423,6 @@ void FinalVM::InstExternal(const VectorRef &args) {
  VectorRef tuple;
  RunFunctionRef run_ref = utils::cast<RunFunctionRef>(args[0]);
  compile::RunFuncPtr fn = run_ref.func_;
  if (backend_->simu_flag()) {
    MS_LOG(DEBUG) << "Simu run";
    if (args.size() == 1) {
      MS_LOG(EXCEPTION) << "The number of args should be greater than 1, but got 1";
    }
    auto simu_run_ref = utils::cast<RunFunctionRef>(args[1]);
    fn = simu_run_ref.func_;
  }
  for (size_t i = 2; i < args.size(); ++i) {
    auto index = utils::cast<int>(args[i]);
    tuple.push_back(Ref(index));

@@ -96,7 +96,6 @@ class FinalVM {
 public:
  // Create a VM with the specified instructions and backend.
  explicit FinalVM(const InstSet &insts, const BackendPtr &backend);

  virtual ~FinalVM() = default;

  BaseRef Eval(const VectorRef &args);

@@ -104,10 +103,8 @@ class FinalVM {
  void InstTailCall(const VectorRef &args);
  void InstReturn(const VectorRef &args);
  void InstPartial(const VectorRef &args);
  void InstSimuPartial(const VectorRef &args);
  void InstRealPartial(const VectorRef &args);
  void InstSwitch(const VectorRef &args);
  void InstSimuSwitch(const VectorRef &args);
  void InstRealSwitch(const VectorRef &args);
  void InstTuple(const VectorRef &args);
  void InstPush(const VectorRef &args);

@@ -129,23 +126,16 @@ class FinalVM {
  void Popp();
  void Pushsp();
  void Popsp();
  void PushStatus(bool is_switch_call);
  bool PopStatus();
  void DoJmp(const BaseRef &jmp);
  void SyncData(const py::object &args);
  void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
  BaseRef MergeArgs(const BaseRef &first, const BaseRef &second);

 private:
  InstSet insts_;
  std::deque<BaseRef> insts_stack_;
  std::stack<int> retp_;
  std::stack<int> retsp_;
  std::stack<bool> ret_status_;
  int pc_;
  int sp_;
  std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_jmp_;
  std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_out_;
  BackendPtr backend_;
  const InstFunctionMap inst_function_map = {
    {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},