!3110 Delete deprecated codes of ascend control flow
Merge pull request !3110 from zhoufeng/delete-deprecated-codes
This commit is contained in:
commit
c20cd12216
File diff suppressed because it is too large
Load Diff
|
@ -51,26 +51,16 @@ class AscendSession : public SessionBasic {
|
||||||
py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors) override;
|
const std::vector<tensor::TensorPtr> &input_tensors) override;
|
||||||
|
|
||||||
// set parameters of final graph
|
|
||||||
GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &args) override;
|
|
||||||
// set output of final graph
|
|
||||||
void SetFinalGraphOutput(const BaseRef &output) override;
|
|
||||||
// insert switch and set the relative active ops
|
|
||||||
void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override;
|
|
||||||
// set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
|
|
||||||
void SetChildGraphInput(GraphId g, const VectorRef &args) override;
|
|
||||||
// get graph id in child graphs by ME front anf node pointer
|
// get graph id in child graphs by ME front anf node pointer
|
||||||
GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override;
|
GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override;
|
||||||
// get graph id of final graph
|
// get graph id of final graph
|
||||||
GraphId GetFinalRunGraph() const override { return final_graph_id_; }
|
GraphId GetFinalRunGraph() const override { return final_graph_id_; }
|
||||||
// insert active to graph
|
|
||||||
void SetActive(GraphId, GraphId) override;
|
|
||||||
// compile child graph when session have multiple child graphs
|
// compile child graph when session have multiple child graphs
|
||||||
void CompileChildGraph(const KernelGraphPtr &child_graph);
|
void CompileChildGraph(const KernelGraphPtr &child_graph);
|
||||||
void RecurseGetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
|
|
||||||
void GetSummaryNodes(KernelGraph *graph);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void RecurseSetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
|
||||||
|
void SetSummaryNodes(KernelGraph *graph) override;
|
||||||
void InitRuntimeResource();
|
void InitRuntimeResource();
|
||||||
void SelectKernel(const KernelGraph &kernel_graph) const;
|
void SelectKernel(const KernelGraph &kernel_graph) const;
|
||||||
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||||
|
@ -92,63 +82,21 @@ class AscendSession : public SessionBasic {
|
||||||
void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;
|
void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;
|
||||||
void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||||
|
|
||||||
size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index);
|
static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
|
||||||
size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index);
|
static void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
|
||||||
size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index);
|
|
||||||
|
|
||||||
void SetFinalGraphOutput(const AnfNodePtr &node);
|
|
||||||
void SetFinalGraphOutput(const ValuePtr &value);
|
|
||||||
void SetFinalGraphOutput(const VectorRef &vec_output);
|
|
||||||
|
|
||||||
void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims,
|
|
||||||
const NotNull<std::set<KernelGraphPtr> *> memo);
|
|
||||||
// split graphs with recurse from root graph
|
|
||||||
void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
|
|
||||||
void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
|
|
||||||
void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
|
|
||||||
void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
|
void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
|
||||||
std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
|
|
||||||
const std::vector<CNodePtr> &list);
|
|
||||||
void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
|
|
||||||
void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
|
|
||||||
AnfNodePtr BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph, const std::vector<CNodePtr> &child_graph_list);
|
|
||||||
|
|
||||||
// merge execution order list of child graphs
|
// merge execution order list of child graphs
|
||||||
void MergeGraphExecOrder();
|
void MergeGraphExecOrder();
|
||||||
// insert assion op to sync data bettween different graphs
|
// insert assion op to sync data bettween different graphs
|
||||||
void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
|
void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
|
||||||
// insert mutiple assigns to graph
|
|
||||||
void InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
|
|
||||||
// insert active op to graph
|
|
||||||
void InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream);
|
|
||||||
// get execute index of graph
|
|
||||||
size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph);
|
|
||||||
// handle condition graph from vm
|
|
||||||
void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id);
|
|
||||||
// insert depend to graph, used to attch control nodes to graph
|
|
||||||
void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node);
|
|
||||||
// insert depend to graph, used to attch control nodes to graph
|
|
||||||
void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
|
|
||||||
// set child graph parameter if front arg is a anf
|
|
||||||
void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
|
|
||||||
// set child graph parameter if front arg is a tensor
|
|
||||||
void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx);
|
|
||||||
// update the execution order of all child graphs
|
|
||||||
void UpdateGraphOrder(GraphId to_graph);
|
|
||||||
// handle switch when merge
|
|
||||||
void MergeSwitchCompile();
|
|
||||||
// get graph order vector by graph id
|
// get graph order vector by graph id
|
||||||
std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id);
|
const std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id) const;
|
||||||
// get graph order type vector by graph id
|
// get graph order type vector by graph id
|
||||||
std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id);
|
const std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id) const;
|
||||||
// copy output of if and else
|
|
||||||
void CopyOutputOfIf(GraphId false_graph_id);
|
|
||||||
// check if graph cache exist
|
// check if graph cache exist
|
||||||
bool GraphCacheExist(const GraphInfo &graph_info) const;
|
bool GraphCacheExist(const GraphInfo &graph_info) const;
|
||||||
// insert all assign to child graph
|
// insert all assign to child graph
|
||||||
void InsertAllAssigns();
|
void InsertAllAssigns();
|
||||||
// create fake output of final graph
|
|
||||||
AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output);
|
|
||||||
// sync intial tensors' data to device
|
// sync intial tensors' data to device
|
||||||
void SyncInitialTenosrToDevice();
|
void SyncInitialTenosrToDevice();
|
||||||
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
|
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||||
|
@ -162,16 +110,10 @@ class AscendSession : public SessionBasic {
|
||||||
void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
|
void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
|
||||||
void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
|
void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
|
||||||
|
|
||||||
// member variables
|
|
||||||
// key is final_graph_id,value is child graph execute order of final graph
|
// key is final_graph_id,value is child graph execute order of final graph
|
||||||
std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
|
std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
|
||||||
// key is final_graph_id,value is the graph types of child graphs
|
// key is final_graph_id,value is the graph types of child graphs
|
||||||
std::unordered_map<GraphId, std::vector<GraphType>> graph_order_types_;
|
std::unordered_map<GraphId, std::vector<GraphType>> graph_order_types_;
|
||||||
// record condition graph of while
|
|
||||||
std::unordered_map<GraphId, GraphId> while_condition_graphs_;
|
|
||||||
// record all conditions
|
|
||||||
std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_;
|
|
||||||
std::unordered_map<GraphId, AnfNodePtr> condition_output_;
|
|
||||||
// share parameters
|
// share parameters
|
||||||
std::vector<std::tuple<AnfNodePtr, GraphId, size_t>> assigns_;
|
std::vector<std::tuple<AnfNodePtr, GraphId, size_t>> assigns_;
|
||||||
// initial tensors, these tensor will sync data to device before run graph
|
// initial tensors, these tensor will sync data to device before run graph
|
||||||
|
|
|
@ -108,7 +108,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
|
||||||
kernel_graph->set_execution_order(execution_order);
|
kernel_graph->set_execution_order(execution_order);
|
||||||
NamedSummaryOutputs summary_outputs;
|
NamedSummaryOutputs summary_outputs;
|
||||||
if (enable_summary) {
|
if (enable_summary) {
|
||||||
GetSummaryNodes(kernel_graph.get());
|
SetSummaryNodes(kernel_graph.get());
|
||||||
summary_outputs = kernel_graph->summary_nodes();
|
summary_outputs = kernel_graph->summary_nodes();
|
||||||
runtime_.IncreaseSummaryRefCount(summary_outputs);
|
runtime_.IncreaseSummaryRefCount(summary_outputs);
|
||||||
}
|
}
|
||||||
|
|
|
@ -217,7 +217,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
|
||||||
Reorder(&execution_order);
|
Reorder(&execution_order);
|
||||||
graph->set_execution_order(execution_order);
|
graph->set_execution_order(execution_order);
|
||||||
// Get summary nodes.
|
// Get summary nodes.
|
||||||
GetSummaryNodes(graph.get());
|
SetSummaryNodes(graph.get());
|
||||||
// Remove NoOp from execution graph
|
// Remove NoOp from execution graph
|
||||||
opt::RemoveNopNode(graph.get());
|
opt::RemoveNopNode(graph.get());
|
||||||
// Set graph manager.
|
// Set graph manager.
|
||||||
|
|
|
@ -898,27 +898,6 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP
|
||||||
std::queue<AnfNodePtr> seed_nodes;
|
std::queue<AnfNodePtr> seed_nodes;
|
||||||
UpdateNodeEdgeList(&seed_nodes);
|
UpdateNodeEdgeList(&seed_nodes);
|
||||||
}
|
}
|
||||||
// update graph inputs in child graph
|
|
||||||
auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(),
|
|
||||||
[&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
|
|
||||||
return n.first == old_anf_node.get();
|
|
||||||
});
|
|
||||||
if (it_real_inputs != real_inputs_.end()) {
|
|
||||||
// erase old parameter in map
|
|
||||||
auto old_args = it_real_inputs->second;
|
|
||||||
real_inputs_.erase(it_real_inputs);
|
|
||||||
// insert new parameter to map
|
|
||||||
auto iter = std::find_if(real_inputs_.begin(), real_inputs_.end(),
|
|
||||||
[&new_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
|
|
||||||
return n.first == new_anf_node.get();
|
|
||||||
});
|
|
||||||
if (iter != real_inputs_.end()) {
|
|
||||||
MS_LOG(WARNING) << new_anf_node->DebugString() << " Already exist in real inputs, will be rewrited.";
|
|
||||||
iter->second = old_args;
|
|
||||||
} else {
|
|
||||||
real_inputs_.emplace_back(new_anf_node, old_args);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelGraph::UpdateExecuteKernelStreamLabel() {
|
void KernelGraph::UpdateExecuteKernelStreamLabel() {
|
||||||
|
@ -953,56 +932,6 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg) {
|
|
||||||
MS_EXCEPTION_IF_NULL(parameter);
|
|
||||||
MS_EXCEPTION_IF_NULL(arg);
|
|
||||||
MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
|
|
||||||
MS_EXCEPTION_IF_NULL(parameter);
|
|
||||||
MS_EXCEPTION_IF_NULL(arg);
|
|
||||||
auto iter = std::find_if(
|
|
||||||
real_inputs_.begin(), real_inputs_.end(),
|
|
||||||
[¶meter](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool { return n.first == parameter; });
|
|
||||||
if (iter != real_inputs_.end()) {
|
|
||||||
auto &args = iter->second;
|
|
||||||
args.push_back(arg);
|
|
||||||
} else {
|
|
||||||
real_inputs_.emplace_back(parameter, std::vector<AnfNodePtr>(1, arg));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph) {
|
|
||||||
unreuse_args_[arg] = from_graph;
|
|
||||||
}
|
|
||||||
|
|
||||||
void KernelGraph::UpdateCallRealInput() {
|
|
||||||
MS_LOG(INFO) << "Update graph id: " << graph_id_;
|
|
||||||
std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_map;
|
|
||||||
for (auto &it : real_inputs_) {
|
|
||||||
auto parameter = it.first;
|
|
||||||
MS_EXCEPTION_IF_NULL(parameter);
|
|
||||||
auto real_inputs = it.second;
|
|
||||||
std::vector<AnfNodePtr> new_real_inputs;
|
|
||||||
for (auto &real_input : real_inputs) {
|
|
||||||
// if real input is a call node ,find the child graph output act as the new real input
|
|
||||||
auto tmp_real_input = GetCallRealOutputs(real_input);
|
|
||||||
std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs));
|
|
||||||
// replace the call in unreuse_args_
|
|
||||||
auto unreuse_arg_it = unreuse_args_.find(real_input);
|
|
||||||
if (unreuse_arg_it != unreuse_args_.end()) {
|
|
||||||
auto old_graph = unreuse_arg_it->second;
|
|
||||||
for (auto new_real_input : new_real_inputs) {
|
|
||||||
// if call reference graph output is parameter, it will be allowed to reuse
|
|
||||||
if (!new_real_input->isa<Parameter>()) {
|
|
||||||
unreuse_args_[new_real_input] = old_graph;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
real_inputs_map.emplace_back(parameter, new_real_inputs);
|
|
||||||
}
|
|
||||||
real_inputs_ = real_inputs_map;
|
|
||||||
}
|
|
||||||
|
|
||||||
void KernelGraph::PrintGraphExecuteOrder() const {
|
void KernelGraph::PrintGraphExecuteOrder() const {
|
||||||
MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order";
|
MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order";
|
||||||
for (size_t i = 0; i < execution_order_.size(); i++) {
|
for (size_t i = 0; i < execution_order_.size(); i++) {
|
||||||
|
|
|
@ -131,16 +131,8 @@ class KernelGraph : public FuncGraph {
|
||||||
void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
|
void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
|
||||||
// find anf node in graph
|
// find anf node in graph
|
||||||
std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
|
std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
|
||||||
// get real inputs
|
|
||||||
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs() const { return real_inputs_; }
|
|
||||||
void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg);
|
|
||||||
// mark unreused args
|
|
||||||
void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph);
|
|
||||||
const std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> &unreuse_args() const { return unreuse_args_; }
|
|
||||||
// used to dump ir
|
// used to dump ir
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
// update the real input if the node is a call
|
|
||||||
void UpdateCallRealInput();
|
|
||||||
|
|
||||||
void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
|
void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
|
||||||
CNodePtr get_start_label() { return start_label_; }
|
CNodePtr get_start_label() { return start_label_; }
|
||||||
|
@ -212,9 +204,6 @@ class KernelGraph : public FuncGraph {
|
||||||
// valid inputs
|
// valid inputs
|
||||||
std::vector<bool> valid_inputs_;
|
std::vector<bool> valid_inputs_;
|
||||||
|
|
||||||
// new members for control sink process
|
|
||||||
// all child grahs refers to partial node
|
|
||||||
std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_;
|
|
||||||
// child graph execute order in root graph
|
// child graph execute order in root graph
|
||||||
std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
|
std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
|
||||||
|
|
||||||
|
@ -223,9 +212,6 @@ class KernelGraph : public FuncGraph {
|
||||||
|
|
||||||
// parameter graph
|
// parameter graph
|
||||||
std::shared_ptr<KernelGraph> parent_graph_;
|
std::shared_ptr<KernelGraph> parent_graph_;
|
||||||
// record real parameters,inputs_ is the formal parameters
|
|
||||||
std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_;
|
|
||||||
std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> unreuse_args_;
|
|
||||||
|
|
||||||
CNodePtr start_label_;
|
CNodePtr start_label_;
|
||||||
CNodePtr end_goto_;
|
CNodePtr end_goto_;
|
||||||
|
|
|
@ -890,7 +890,7 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
|
||||||
|
|
||||||
void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }
|
void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }
|
||||||
|
|
||||||
void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
|
void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
|
||||||
MS_LOG(DEBUG) << "Update summary Start";
|
MS_LOG(DEBUG) << "Update summary Start";
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
if (!graph->summary_node_exist()) {
|
if (!graph->summary_node_exist()) {
|
||||||
|
@ -930,7 +930,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
|
||||||
if (!exist_summary) {
|
if (!exist_summary) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
GetSummaryNodes(graph);
|
SetSummaryNodes(graph);
|
||||||
auto summary_outputs = graph->summary_nodes();
|
auto summary_outputs = graph->summary_nodes();
|
||||||
std::map<std::string, tensor::TensorPtr> params_list;
|
std::map<std::string, tensor::TensorPtr> params_list;
|
||||||
// fetch outputs apply kernel in session & run callback functions
|
// fetch outputs apply kernel in session & run callback functions
|
||||||
|
|
|
@ -92,19 +92,9 @@ class SessionBasic {
|
||||||
CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph);
|
CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph);
|
||||||
std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
|
std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
|
||||||
|
|
||||||
// set parameters of final graph
|
|
||||||
virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; }
|
|
||||||
// set output of final graph
|
|
||||||
virtual void SetFinalGraphOutput(const BaseRef &) {}
|
|
||||||
// insert switch and set the relative active ops
|
|
||||||
virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {}
|
|
||||||
// set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
|
|
||||||
virtual void SetChildGraphInput(GraphId, const VectorRef &) {}
|
|
||||||
// get graph id in child graphs by ME front anf node pointer
|
// get graph id in child graphs by ME front anf node pointer
|
||||||
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
|
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
|
||||||
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
|
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
|
||||||
virtual void SetActive(GraphId, GraphId) {}
|
|
||||||
virtual void GetSummaryNodes(KernelGraph *graph);
|
|
||||||
void AssignParamKey(const KernelGraphPtr &kernel_graph);
|
void AssignParamKey(const KernelGraphPtr &kernel_graph);
|
||||||
void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
|
void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
|
||||||
virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const { return true; }
|
virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const { return true; }
|
||||||
|
@ -120,6 +110,7 @@ class SessionBasic {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
virtual void SetSummaryNodes(KernelGraph *graph);
|
||||||
// Get graph by graph id ,if not exist return null ptr
|
// Get graph by graph id ,if not exist return null ptr
|
||||||
KernelGraphPtr GetGraph(GraphId graph_id) const;
|
KernelGraphPtr GetGraph(GraphId graph_id) const;
|
||||||
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "base.cc" "pynative_execute.cc")
|
file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute.cc")
|
||||||
|
|
||||||
if (ENABLE_GE)
|
if (ENABLE_GE)
|
||||||
file(GLOB_RECURSE _GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute_ge.cc")
|
file(GLOB_RECURSE _GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute_ge.cc")
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "utils/callbacks.h"
|
#include "utils/callbacks.h"
|
||||||
#include "utils/graph_utils.h"
|
|
||||||
#include "utils/base_ref_extends.h"
|
#include "utils/base_ref_extends.h"
|
||||||
#include "backend/session/session_factory.h"
|
#include "backend/session/session_factory.h"
|
||||||
#include "common/utils.h"
|
#include "common/utils.h"
|
||||||
|
@ -34,19 +33,6 @@ namespace compile {
|
||||||
bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); }
|
bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); }
|
||||||
bool Backend::GetIndex(const BaseRef &c, int *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }
|
bool Backend::GetIndex(const BaseRef &c, int *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }
|
||||||
|
|
||||||
LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
|
|
||||||
// multi_graph merge to one, big graph have paramters in begin and only have one output
|
|
||||||
MS_LOG(DEBUG) << "graph:" << g->ToString() << " parameter size:" << g->parameters().size();
|
|
||||||
multi_result_.inputs = g->parameters();
|
|
||||||
final_output_ = NewValueNode("fake_output");
|
|
||||||
multi_result_.outputs = {final_output_};
|
|
||||||
GraphId final_g = target_sess_->GetFinalRunGraph();
|
|
||||||
|
|
||||||
multi_result_.run = std::make_shared<RunFunc>(
|
|
||||||
[final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args, ""); });
|
|
||||||
return multi_result_;
|
|
||||||
}
|
|
||||||
|
|
||||||
LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) {
|
LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) {
|
||||||
MS_LOG(DEBUG) << "MsConvert";
|
MS_LOG(DEBUG) << "MsConvert";
|
||||||
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
||||||
|
@ -96,149 +82,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
|
|
||||||
GraphId active_g = simu_cond_map_[c].cond_graph_map[cond];
|
|
||||||
|
|
||||||
GraphId cond_g = kInvalidGraphId;
|
|
||||||
if (utils::isa<AnfNodePtr>(c)) {
|
|
||||||
cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
|
|
||||||
} else {
|
|
||||||
MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString();
|
|
||||||
}
|
|
||||||
auto before_cond = curr_switch_;
|
|
||||||
if (curr_switch_.hash() != c.hash()) {
|
|
||||||
// invoke while false->before true call
|
|
||||||
if (simu_cond_map_[before_cond].cond_graph_map.count(false)) {
|
|
||||||
active_g = simu_cond_map_[before_cond].cond_graph_map[false];
|
|
||||||
} else {
|
|
||||||
active_g = kInvalidGraphId;
|
|
||||||
}
|
|
||||||
// while x < y:
|
|
||||||
// z = y + 1
|
|
||||||
// while z < c2:
|
|
||||||
// out = out + 1
|
|
||||||
// z = z + 1
|
|
||||||
if (active_g == cond_g) {
|
|
||||||
active_g = kInvalidGraphId;
|
|
||||||
simu_cond_map_[before_cond].cond_graph_map[false] = kInvalidGraphId;
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "invoke set active:" << active_g;
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g;
|
|
||||||
target_sess_->SetActive(active_g, cond_g);
|
|
||||||
}
|
|
||||||
|
|
||||||
void MsBackend::SetSwitchGraph() {
|
|
||||||
MS_LOG(DEBUG) << "SetSwitchGraph curr_switch:" << curr_switch_.ToString();
|
|
||||||
|
|
||||||
if (is_switch_call_) {
|
|
||||||
GraphId false_g = kInvalidGraphId;
|
|
||||||
GraphId true_g = kInvalidGraphId;
|
|
||||||
MS_LOG(DEBUG) << "start SetSwitchGraph";
|
|
||||||
true_g = simu_cond_map_[curr_switch_].cond_graph_map[true];
|
|
||||||
bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
|
|
||||||
if (!curr_cond) {
|
|
||||||
if (simu_cond_map_[curr_switch_].cond_graph_map.count(curr_cond)) {
|
|
||||||
// has false branch
|
|
||||||
false_g = simu_cond_map_[curr_switch_].cond_graph_map[false];
|
|
||||||
}
|
|
||||||
GraphId cond_g = kInvalidGraphId;
|
|
||||||
if (utils::isa<AnfNodePtr>(curr_switch_)) {
|
|
||||||
cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
|
|
||||||
} else {
|
|
||||||
MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
|
|
||||||
target_sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
|
|
||||||
}
|
|
||||||
is_switch_call_ = false;
|
|
||||||
MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// convert node from formal parameter to actual parameter,
|
|
||||||
// and actual parameter is graph user's formal parameter.
|
|
||||||
// get top while graph's parameter in recall while.
|
|
||||||
AnfNodePtr MsBackend::ConvertGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
|
||||||
std::unordered_map<AnfNodePtr, size_t> params_index;
|
|
||||||
auto result = node;
|
|
||||||
auto graph = result->func_graph();
|
|
||||||
while (func_graph != graph) {
|
|
||||||
auto iter = graph_user_inputs_.find(graph);
|
|
||||||
if (iter == graph_user_inputs_.end()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
params_index.clear();
|
|
||||||
auto ¶ms = graph->parameters();
|
|
||||||
for (size_t i = 0; i < params.size(); ++i) {
|
|
||||||
params_index[params[i]] = i;
|
|
||||||
}
|
|
||||||
|
|
||||||
graph = iter->second.first;
|
|
||||||
auto &inputs = iter->second.second;
|
|
||||||
result = inputs[params_index[result]];
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
void MsBackend::SetGraphUserInputs(const FuncGraphPtr &func_graph, const FuncGraphPtr &user,
|
|
||||||
const AnfNodePtrList &inputs) {
|
|
||||||
if (graph_user_inputs_.find(func_graph) != graph_user_inputs_.end()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
graph_user_inputs_[func_graph] = {user, inputs};
|
|
||||||
}
|
|
||||||
|
|
||||||
void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args, const BaseRef &c) {
|
|
||||||
std::unordered_map<AnfNodePtr, size_t> params_index;
|
|
||||||
auto ¶ms = func_graph->parameters();
|
|
||||||
for (size_t i = 0; i < params.size(); ++i) {
|
|
||||||
params_index[params[i]] = i;
|
|
||||||
}
|
|
||||||
|
|
||||||
// recall all child graphs in this while
|
|
||||||
auto &graph_inputs = graph_inputs_[c];
|
|
||||||
for (auto &iter : graph_inputs) {
|
|
||||||
auto &graph = iter.first;
|
|
||||||
auto &old_args = iter.second;
|
|
||||||
auto &result = graph_id_map_[graph];
|
|
||||||
auto &inputs = result.inputs;
|
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
|
||||||
auto input = ConvertGraphInput(func_graph, inputs[i]);
|
|
||||||
auto it = params_index.find(input);
|
|
||||||
if (it != params_index.end()) {
|
|
||||||
old_args[i] = args[it->second];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
target_sess_->SetChildGraphInput(graph, old_args);
|
|
||||||
}
|
|
||||||
graph_inputs_.erase(c);
|
|
||||||
}
|
|
||||||
|
|
||||||
// compile set input output
|
// compile set input output
|
||||||
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
|
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
|
||||||
MS_LOG(DEBUG) << "set graph input:" << g;
|
MS_LOG(DEBUG) << "set graph input:" << g;
|
||||||
// switch maybe twice
|
|
||||||
target_sess_->SetChildGraphInput(g, args);
|
|
||||||
|
|
||||||
if (is_switch_call_) {
|
|
||||||
if (!curr_switch_.is_null()) {
|
|
||||||
// push this {g, args} to all user while graph_inputs for nest while,
|
|
||||||
// when current condition recall over delete this cond in graph_inputs.
|
|
||||||
for (auto &iter : graph_inputs_) {
|
|
||||||
iter.second.push_back({g, args});
|
|
||||||
}
|
|
||||||
if (graph_inputs_.find(curr_switch_) == graph_inputs_.end()) {
|
|
||||||
graph_inputs_[curr_switch_].push_back({g, args});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
|
|
||||||
MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g;
|
|
||||||
simu_cond_map_[curr_switch_].cond_graph_map[curr_cond] = g;
|
|
||||||
SetSwitchGraph();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<BaseRef> outputs;
|
std::vector<BaseRef> outputs;
|
||||||
(void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
|
(void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
|
||||||
[](const AnfNodePtr &v) { return v; });
|
[](const AnfNodePtr &v) { return v; });
|
||||||
|
@ -290,36 +136,6 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s
|
||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
SwitchCondStatus MsBackend::SetSimuCond(const BaseRef &c, bool value) {
|
|
||||||
MS_LOG(DEBUG) << "set cond :" << c.ToString() << ", " << simu_cond_map_.size();
|
|
||||||
|
|
||||||
CondGraph cond_graph;
|
|
||||||
cond_graph.curr_cond = value;
|
|
||||||
if (simu_cond_map_.find(c) == simu_cond_map_.end()) {
|
|
||||||
simu_cond_map_[c] = cond_graph;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (simu_cond_map_[c].cond_graph_map.count(value)) {
|
|
||||||
return kCondAlreadyRun;
|
|
||||||
}
|
|
||||||
simu_cond_map_[c].curr_cond = value;
|
|
||||||
MS_LOG(DEBUG) << "end set cond ";
|
|
||||||
return kCondOk;
|
|
||||||
}
|
|
||||||
|
|
||||||
void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) {
|
|
||||||
MS_LOG(DEBUG) << "Simulate run,root:" << root->ToString() << ", " << root->parameters().size();
|
|
||||||
std::vector<BaseRef> args;
|
|
||||||
auto parameters = root->parameters();
|
|
||||||
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args),
|
|
||||||
[](const AnfNodePtr &v) { return v; });
|
|
||||||
MS_LOG(DEBUG) << "Simulate start";
|
|
||||||
(void)target_sess_->SetFinalGraphInput(parameters);
|
|
||||||
BaseRef output = rt->Eval(VectorRef(args));
|
|
||||||
target_sess_->SetFinalGraphOutput(output);
|
|
||||||
MS_LOG(DEBUG) << "Simulate Eval end";
|
|
||||||
}
|
|
||||||
|
|
||||||
void MsBackend::Link(GraphId graph_id) {
|
void MsBackend::Link(GraphId graph_id) {
|
||||||
if (graph_id == kInvalidGraphId) {
|
if (graph_id == kInvalidGraphId) {
|
||||||
graph_id = target_sess_->GetFinalRunGraph();
|
graph_id = target_sess_->GetFinalRunGraph();
|
||||||
|
@ -330,9 +146,6 @@ void MsBackend::Link(GraphId graph_id) {
|
||||||
Backend::Backend(const std::string &name) : name_(name) {
|
Backend::Backend(const std::string &name) : name_(name) {
|
||||||
MS_LOG(DEBUG) << "select backend:" << name;
|
MS_LOG(DEBUG) << "select backend:" << name;
|
||||||
convert_fn_ = backends[name_];
|
convert_fn_ = backends[name_];
|
||||||
is_switch_call_ = false;
|
|
||||||
is_multi_graph_sink_ = false;
|
|
||||||
simu_flag_ = false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
|
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
|
||||||
|
|
|
@ -43,50 +43,19 @@ class Backend {
|
||||||
|
|
||||||
LinkFuncType convert_fn() { return convert_fn_; }
|
LinkFuncType convert_fn() { return convert_fn_; }
|
||||||
std::string name() { return name_; }
|
std::string name() { return name_; }
|
||||||
virtual void SimulateRun(FinalVMPtr, FuncGraphPtr) {}
|
|
||||||
virtual SwitchCondStatus SetSimuCond(const BaseRef &, bool) { return kCondOk; }
|
|
||||||
virtual bool GetCond(const BaseRef &c, bool *value);
|
virtual bool GetCond(const BaseRef &c, bool *value);
|
||||||
virtual bool GetIndex(const BaseRef &c, int *value);
|
virtual bool GetIndex(const BaseRef &c, int *value);
|
||||||
virtual void SetSwitchGraph() {}
|
|
||||||
virtual void SetSwitchActive(const BaseRef &, bool) {}
|
|
||||||
virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {}
|
|
||||||
virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {}
|
|
||||||
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
|
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
|
||||||
void set_curr_switch(const BaseRef &value) {
|
|
||||||
curr_switch_ = value;
|
|
||||||
is_switch_call_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
BaseRef curr_switch() { return curr_switch_; }
|
|
||||||
virtual void Link(GraphId) {}
|
virtual void Link(GraphId) {}
|
||||||
virtual LinConvertResult GetMultiGraphRun(const FuncGraphPtr &) { return LinConvertResult(); }
|
virtual void SetDebugger() {}
|
||||||
|
|
||||||
LinConvertResult multi_result() { return multi_result_; }
|
|
||||||
void set_multi_result(const LinConvertResult &value) { multi_result_ = value; }
|
|
||||||
AnfNodePtr final_output() const { return final_output_; }
|
|
||||||
bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
|
bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
|
||||||
void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; }
|
void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; }
|
||||||
bool simu_flag() const { return simu_flag_; }
|
|
||||||
bool is_switch_call() const { return is_switch_call_; }
|
|
||||||
void set_simu_flag(bool simu) { simu_flag_ = simu; }
|
|
||||||
|
|
||||||
virtual void SetDebugger() {}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::string name_;
|
std::string name_;
|
||||||
LinkFuncType convert_fn_;
|
LinkFuncType convert_fn_;
|
||||||
BaseRef curr_switch_; // curr switch node
|
|
||||||
bool is_multi_graph_sink_;
|
bool is_multi_graph_sink_;
|
||||||
bool is_switch_call_;
|
|
||||||
bool simu_flag_;
|
|
||||||
LinConvertResult multi_result_;
|
|
||||||
AnfNodePtr final_output_;
|
|
||||||
std::unordered_map<FuncGraphPtr, std::pair<FuncGraphPtr, AnfNodePtrList>> graph_user_inputs_;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct CondGraph {
|
|
||||||
bool curr_cond;
|
|
||||||
std::unordered_map<bool, GraphId> cond_graph_map;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class MsBackend : public Backend {
|
class MsBackend : public Backend {
|
||||||
|
@ -98,16 +67,7 @@ class MsBackend : public Backend {
|
||||||
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");
|
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");
|
||||||
|
|
||||||
VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args);
|
VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args);
|
||||||
void SimulateRun(FinalVMPtr rt, FuncGraphPtr root) override;
|
|
||||||
SwitchCondStatus SetSimuCond(const BaseRef &c, bool value) override;
|
|
||||||
|
|
||||||
void SetSwitchGraph() override;
|
|
||||||
void SetSwitchActive(const BaseRef &c, bool cond) override;
|
|
||||||
void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) override;
|
|
||||||
void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) override;
|
|
||||||
void Link(GraphId) override;
|
void Link(GraphId) override;
|
||||||
AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &);
|
|
||||||
LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override;
|
|
||||||
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
|
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
|
||||||
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
|
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
|
||||||
void CreateOtherSession(const std::string &target);
|
void CreateOtherSession(const std::string &target);
|
||||||
|
@ -121,9 +81,7 @@ class MsBackend : public Backend {
|
||||||
session::SessionPtr other_sess_;
|
session::SessionPtr other_sess_;
|
||||||
std::string target_device_;
|
std::string target_device_;
|
||||||
std::string other_device_;
|
std::string other_device_;
|
||||||
std::unordered_map<BaseRef, CondGraph, BaseRefHash> simu_cond_map_;
|
|
||||||
std::unordered_map<GraphId, LinConvertResult> graph_id_map_;
|
std::unordered_map<GraphId, LinConvertResult> graph_id_map_;
|
||||||
std::unordered_map<BaseRef, std::list<std::pair<GraphId, VectorRef>>, BaseRefHash> graph_inputs_;
|
|
||||||
};
|
};
|
||||||
} // namespace compile
|
} // namespace compile
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -515,11 +515,7 @@ int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &no
|
||||||
MS_LOG(DEBUG) << "LinConvert start";
|
MS_LOG(DEBUG) << "LinConvert start";
|
||||||
LinConvertResult result;
|
LinConvertResult result;
|
||||||
|
|
||||||
if (backend_->simu_flag()) {
|
result = lin_convert_(node_list, target);
|
||||||
result = backend_->GetMultiGraphRun(graph);
|
|
||||||
} else {
|
|
||||||
result = lin_convert_(node_list, target);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (result.run == nullptr) {
|
if (result.run == nullptr) {
|
||||||
MS_LOG(ERROR) << "LinConvert failed";
|
MS_LOG(ERROR) << "LinConvert failed";
|
||||||
|
@ -546,27 +542,6 @@ int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &no
|
||||||
return RET_SUCCESS;
|
return RET_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
|
|
||||||
MS_LOG(DEBUG) << "AddSinkSwitch:" << node->ToString();
|
|
||||||
if (backend_->is_multi_graph_sink()) {
|
|
||||||
VectorRef args;
|
|
||||||
args.emplace_back(-1);
|
|
||||||
MS_LOG(DEBUG) << "call::" << height_;
|
|
||||||
AddInst(Instruction::kCall, args);
|
|
||||||
|
|
||||||
args.clear();
|
|
||||||
args.emplace_back(node->input(1));
|
|
||||||
AddInst(Instruction::kSwitchReturn, args);
|
|
||||||
|
|
||||||
args.clear();
|
|
||||||
args.emplace_back(false);
|
|
||||||
args.emplace_back(Ref(node->input(1)));
|
|
||||||
args.emplace_back(Ref(node->input(2)));
|
|
||||||
args.emplace_back(Ref(node->input(3)));
|
|
||||||
AddInst(Instruction::kSwitch, args);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
|
int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);
|
MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);
|
||||||
|
@ -589,7 +564,6 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
|
||||||
AddPartial(node);
|
AddPartial(node);
|
||||||
} else if (IsPrimitive(fn, prim::kPrimSwitch)) {
|
} else if (IsPrimitive(fn, prim::kPrimSwitch)) {
|
||||||
AddSwitch(node);
|
AddSwitch(node);
|
||||||
AddSinkSwitch(node);
|
|
||||||
} else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) {
|
} else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) {
|
||||||
AddSwitchLayer(node);
|
AddSwitchLayer(node);
|
||||||
} else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
|
} else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
|
||||||
|
@ -607,14 +581,6 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
|
||||||
return RET_SUCCESS;
|
return RET_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr &graph) {
|
|
||||||
auto ret = LinConvert(graph, {});
|
|
||||||
if (ret == RET_FAILED) {
|
|
||||||
MS_LOG(EXCEPTION) << "MultiGraphRun failed.";
|
|
||||||
}
|
|
||||||
AddReturn(nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
|
bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
|
||||||
MS_LOG(DEBUG) << "Start split graph";
|
MS_LOG(DEBUG) << "Start split graph";
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
@ -659,11 +625,6 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr &graph) {
|
|
||||||
InstSet inst = Run(graph);
|
|
||||||
return inst;
|
|
||||||
}
|
|
||||||
|
|
||||||
InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
|
InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
|
||||||
|
@ -672,12 +633,8 @@ InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
|
||||||
int param_height = height_;
|
int param_height = height_;
|
||||||
MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
|
MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
|
||||||
|
|
||||||
if (backend_->simu_flag()) {
|
if (!SplitGraph(graph)) {
|
||||||
GenMultiGraphsRun(graph);
|
return inst_;
|
||||||
} else {
|
|
||||||
if (!SplitGraph(graph)) {
|
|
||||||
return inst_;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
AddPadStack(param_height);
|
AddPadStack(param_height);
|
||||||
|
@ -712,12 +669,6 @@ void CompileGraph::AddPartial(const CNodePtr &node) {
|
||||||
if (!IsValueNode<FuncGraph>(fn)) {
|
if (!IsValueNode<FuncGraph>(fn)) {
|
||||||
MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
|
MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
|
||||||
}
|
}
|
||||||
if (backend_->is_multi_graph_sink()) {
|
|
||||||
auto func_graph = GetValueNode<FuncGraphPtr>(fn);
|
|
||||||
args.emplace_back(func_graph);
|
|
||||||
AnfNodePtrList outs(inputs.begin() + 2, inputs.end());
|
|
||||||
backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
|
|
||||||
}
|
|
||||||
for (size_t i = 1; i < inputs.size(); i++) {
|
for (size_t i = 1; i < inputs.size(); i++) {
|
||||||
args.emplace_back(Ref(inputs[i]));
|
args.emplace_back(Ref(inputs[i]));
|
||||||
}
|
}
|
||||||
|
@ -739,9 +690,6 @@ void CompileGraph::AddSwitch(const CNodePtr &node) {
|
||||||
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
|
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
|
||||||
}
|
}
|
||||||
VectorRef args;
|
VectorRef args;
|
||||||
if (backend_->is_multi_graph_sink()) {
|
|
||||||
args.emplace_back(true);
|
|
||||||
}
|
|
||||||
args.emplace_back(Ref(inputs[1]));
|
args.emplace_back(Ref(inputs[1]));
|
||||||
args.emplace_back(Ref(inputs[2]));
|
args.emplace_back(Ref(inputs[2]));
|
||||||
args.emplace_back(Ref(inputs[3]));
|
args.emplace_back(Ref(inputs[3]));
|
||||||
|
@ -761,11 +709,7 @@ void CompileGraph::AddSwitchLayer(const CNodePtr &node) {
|
||||||
|
|
||||||
void CompileGraph::AddReturn(const CNodePtr &node) {
|
void CompileGraph::AddReturn(const CNodePtr &node) {
|
||||||
VectorRef args;
|
VectorRef args;
|
||||||
if (backend_->simu_flag()) {
|
args.emplace_back(Ref(node->input(1)));
|
||||||
args.emplace_back(Ref(backend_->final_output()));
|
|
||||||
} else {
|
|
||||||
args.emplace_back(Ref(node->input(1)));
|
|
||||||
}
|
|
||||||
args.emplace_back(height_);
|
args.emplace_back(height_);
|
||||||
AddInst(Instruction::kReturn, args);
|
AddInst(Instruction::kReturn, args);
|
||||||
}
|
}
|
||||||
|
@ -783,11 +727,6 @@ void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim)
|
||||||
int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
|
int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
|
||||||
auto inputs = node->inputs();
|
auto inputs = node->inputs();
|
||||||
AnfNodePtr fn = inputs[0];
|
AnfNodePtr fn = inputs[0];
|
||||||
if (backend_->is_multi_graph_sink() && IsValueNode<FuncGraph>(fn)) {
|
|
||||||
auto func_graph = GetValueNode<FuncGraphPtr>(fn);
|
|
||||||
AnfNodePtrList outs(inputs.begin() + 1, inputs.end());
|
|
||||||
backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
|
|
||||||
}
|
|
||||||
(void)Ref(fn);
|
(void)Ref(fn);
|
||||||
size_t size = inputs.size();
|
size_t size = inputs.size();
|
||||||
for (size_t i = size - 1; i > 0; i--) {
|
for (size_t i = size - 1; i > 0; i--) {
|
||||||
|
@ -929,17 +868,6 @@ FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) {
|
||||||
}
|
}
|
||||||
|
|
||||||
FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
|
FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
|
||||||
if (backend_->is_multi_graph_sink()) {
|
|
||||||
backend_->set_simu_flag(true);
|
|
||||||
MS_LOG(DEBUG) << "Start simulate";
|
|
||||||
backend_->SimulateRun(rt, graph);
|
|
||||||
MS_LOG(DEBUG) << "Link graphs";
|
|
||||||
insts_ = transform_->GenMultiGraphsSinkInst(graph);
|
|
||||||
rt->set_insts(insts_);
|
|
||||||
backend_->set_simu_flag(false);
|
|
||||||
MS_LOG(DEBUG) << "End start simulate";
|
|
||||||
backend_->Link(kInvalidGraphId);
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "End";
|
MS_LOG(DEBUG) << "End";
|
||||||
return rt;
|
return rt;
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,12 +54,10 @@ class CompileGraph {
|
||||||
~CompileGraph() = default;
|
~CompileGraph() = default;
|
||||||
|
|
||||||
InstSet Run(const FuncGraphPtr &func_graph);
|
InstSet Run(const FuncGraphPtr &func_graph);
|
||||||
InstSet GenMultiGraphsSinkInst(const FuncGraphPtr &graph);
|
|
||||||
bool IsCut(const AnfNodePtr &node);
|
bool IsCut(const AnfNodePtr &node);
|
||||||
void Push(const AnfNodePtr &node);
|
void Push(const AnfNodePtr &node);
|
||||||
void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
|
void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
|
||||||
void Ret(int nargs);
|
void Ret(int nargs);
|
||||||
void GenMultiGraphsRun(const FuncGraphPtr &graph);
|
|
||||||
int Ref(const AnfNodePtr &node);
|
int Ref(const AnfNodePtr &node);
|
||||||
VectorRef SplitNodes(const FuncGraphPtr &func_graph);
|
VectorRef SplitNodes(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
|
@ -84,7 +82,6 @@ class CompileGraph {
|
||||||
int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");
|
int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");
|
||||||
int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
|
int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
|
||||||
int AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
|
int AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
|
||||||
void AddSinkSwitch(const CNodePtr &node);
|
|
||||||
void AddPadStack(int param_height);
|
void AddPadStack(int param_height);
|
||||||
void AddTailCall(const AnfNodePtr &fn, size_t size);
|
void AddTailCall(const AnfNodePtr &fn, size_t size);
|
||||||
void AddPartial(const CNodePtr &node);
|
void AddPartial(const CNodePtr &node);
|
||||||
|
|
|
@ -17,12 +17,9 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "vm/vm.h"
|
#include "vm/vm.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "vm/vmimpl.h"
|
#include "vm/vmimpl.h"
|
||||||
#include "vm/backend.h"
|
#include "vm/backend.h"
|
||||||
#include "vm/transform.h"
|
|
||||||
#include "pipeline/jit/parse/data_converter.h"
|
#include "pipeline/jit/parse/data_converter.h"
|
||||||
#include "utils/base_ref_extends.h"
|
#include "utils/base_ref_extends.h"
|
||||||
|
|
||||||
|
@ -142,33 +139,10 @@ void FinalVM::Popsp() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FinalVM::PushStatus(bool is_switch_call) { ret_status_.push(is_switch_call); }
|
|
||||||
|
|
||||||
bool FinalVM::PopStatus() {
|
|
||||||
if (ret_status_.empty()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool status = ret_status_.top();
|
|
||||||
ret_status_.pop();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
void FinalVM::DoJmp(const BaseRef &jmp_orig) {
|
void FinalVM::DoJmp(const BaseRef &jmp_orig) {
|
||||||
MS_LOG(DEBUG) << "Start";
|
MS_LOG(DEBUG) << "Start";
|
||||||
|
|
||||||
BaseRef jmp = jmp_orig;
|
BaseRef jmp = jmp_orig;
|
||||||
if (backend_->simu_flag()) {
|
|
||||||
bool is_switch_call = false;
|
|
||||||
if (utils::isa<StructSimuSwitch>(jmp)) { // need to inherit from Base
|
|
||||||
MS_LOG(DEBUG) << "Start jump StructSwitch";
|
|
||||||
auto simu_value = utils::cast<std::shared_ptr<StructSimuSwitch>>(jmp);
|
|
||||||
jmp = simu_value->fn_;
|
|
||||||
backend_->set_curr_switch(simu_value->value_);
|
|
||||||
is_switch_call = true;
|
|
||||||
}
|
|
||||||
PushStatus(is_switch_call);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (utils::isa<StructPartial>(jmp)) { // need to inherit from Base
|
if (utils::isa<StructPartial>(jmp)) { // need to inherit from Base
|
||||||
MS_LOG(DEBUG) << "Start jump StructPartial";
|
MS_LOG(DEBUG) << "Start jump StructPartial";
|
||||||
auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
|
auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
|
||||||
|
@ -270,13 +244,6 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) {
|
||||||
MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
|
MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto rv = Ref(-1);
|
|
||||||
if (utils::isa<AnfNodePtr>(rv) || utils::isa<VectorRef>(rv)) {
|
|
||||||
auto &c = args[0];
|
|
||||||
cond_out_[c] = rv;
|
|
||||||
}
|
|
||||||
|
|
||||||
Pop(1);
|
Pop(1);
|
||||||
Popsp();
|
Popsp();
|
||||||
}
|
}
|
||||||
|
@ -294,51 +261,12 @@ void FinalVM::InstReturn(const VectorRef &args) {
|
||||||
int height = utils::cast<int>(args[1]);
|
int height = utils::cast<int>(args[1]);
|
||||||
|
|
||||||
auto rv = Ref(rpos);
|
auto rv = Ref(rpos);
|
||||||
if (backend_->simu_flag()) {
|
|
||||||
auto c = backend_->curr_switch();
|
|
||||||
auto status = PopStatus();
|
|
||||||
if (status) {
|
|
||||||
auto iter = cond_out_.find(c);
|
|
||||||
if (iter != cond_out_.end()) {
|
|
||||||
rv = MergeArgs(rv, iter->second);
|
|
||||||
cond_out_.erase(iter);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (backend_->is_switch_call()) {
|
|
||||||
backend_->SetSwitchGraph();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Pop(height);
|
Pop(height);
|
||||||
Push(rv);
|
Push(rv);
|
||||||
Popp();
|
Popp();
|
||||||
MS_LOG(DEBUG) << "End";
|
MS_LOG(DEBUG) << "End";
|
||||||
}
|
}
|
||||||
|
|
||||||
void FinalVM::InstSimuPartial(const VectorRef &args) {
|
|
||||||
const size_t args_size = 2;
|
|
||||||
if (args.size() < args_size) {
|
|
||||||
MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
|
|
||||||
<< args.size() << ".";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto &node = args[0];
|
|
||||||
if (!utils::isa<FuncGraphPtr>(node)) {
|
|
||||||
MS_LOG(ERROR) << "The type of 1st input of node must be FuncGraph";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto fg = utils::cast<FuncGraphPtr>(node);
|
|
||||||
int fn_ = utils::cast<int>(args[1]);
|
|
||||||
auto fn = utils::cast<int>(Ref(fn_));
|
|
||||||
MS_LOG(DEBUG) << "Partial argssize:" << args.size();
|
|
||||||
std::vector<BaseRef> outs(args.size() - 2);
|
|
||||||
(void)std::transform(args.begin() + 2, args.end(), outs.begin(),
|
|
||||||
[&, this](const BaseRef &a) { return Ref(utils::cast<int>(a)); });
|
|
||||||
Push(std::make_shared<StructPartial>(fn, VectorRef(outs), fg));
|
|
||||||
}
|
|
||||||
|
|
||||||
void FinalVM::InstRealPartial(const VectorRef &args) {
|
void FinalVM::InstRealPartial(const VectorRef &args) {
|
||||||
const size_t args_size = 1;
|
const size_t args_size = 1;
|
||||||
if (args.size() < args_size) {
|
if (args.size() < args_size) {
|
||||||
|
@ -358,91 +286,10 @@ void FinalVM::InstRealPartial(const VectorRef &args) {
|
||||||
|
|
||||||
void FinalVM::InstPartial(const VectorRef &args) {
|
void FinalVM::InstPartial(const VectorRef &args) {
|
||||||
MS_LOG(DEBUG) << "Start";
|
MS_LOG(DEBUG) << "Start";
|
||||||
if (backend_->is_multi_graph_sink()) {
|
InstRealPartial(args);
|
||||||
InstSimuPartial(args);
|
|
||||||
} else {
|
|
||||||
InstRealPartial(args);
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "End";
|
MS_LOG(DEBUG) << "End";
|
||||||
}
|
}
|
||||||
|
|
||||||
void FinalVM::InstSimuSwitch(const VectorRef &args) {
|
|
||||||
const size_t args_size = 4;
|
|
||||||
if (args.size() != args_size) {
|
|
||||||
MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
|
|
||||||
<< ".";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
bool cond = utils::cast<bool>(args[0]);
|
|
||||||
int cond_node = utils::cast<int>(args[1]);
|
|
||||||
int vtrue = utils::cast<int>(args[2]);
|
|
||||||
int vfalse = utils::cast<int>(args[3]);
|
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "Simu switch cond:" << cond;
|
|
||||||
BaseRef c = Ref(cond_node);
|
|
||||||
bool bool_value = cond;
|
|
||||||
SwitchCondStatus cond_stat = backend_->SetSimuCond(c, bool_value);
|
|
||||||
|
|
||||||
if (cond_stat == kCondAlreadyRun) {
|
|
||||||
MS_LOG(DEBUG) << "switch alreay run bool while true jmp";
|
|
||||||
BaseRef jmp = Ref(vtrue);
|
|
||||||
if (utils::isa<StructPartial>(jmp)) {
|
|
||||||
auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
|
|
||||||
backend_->RecallGraphInput(new_jmp->fg_, new_jmp->args_, c);
|
|
||||||
}
|
|
||||||
cond_jmp_[c] = Ref(vfalse);
|
|
||||||
Push(static_cast<int>(cond_stat));
|
|
||||||
Popp();
|
|
||||||
backend_->SetSwitchActive(c, bool_value);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (bool_value) {
|
|
||||||
Push(std::make_shared<StructSimuSwitch>(Ref(vtrue), c));
|
|
||||||
Pushsp();
|
|
||||||
} else {
|
|
||||||
MergeJmpArgs(Ref(vfalse), c);
|
|
||||||
Push(std::make_shared<StructSimuSwitch>(Ref(vfalse), c));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) {
|
|
||||||
auto iter = cond_jmp_.find(c);
|
|
||||||
if (iter == cond_jmp_.end()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto old_jmp = utils::cast<std::shared_ptr<StructPartial>>(iter->second);
|
|
||||||
auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
|
|
||||||
auto &old_args = old_jmp->args_;
|
|
||||||
auto &new_args = new_jmp->args_;
|
|
||||||
for (size_t i = 0; i < new_args.size(); ++i) {
|
|
||||||
auto &old_arg = old_args[i];
|
|
||||||
auto &new_arg = new_args[i];
|
|
||||||
new_arg = MergeArgs(old_arg, new_arg);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BaseRef FinalVM::MergeArgs(const BaseRef &first, const BaseRef &second) {
|
|
||||||
MS_LOG(DEBUG) << __FUNCTION__ << ": " << first.ToString() << ", " << second.ToString();
|
|
||||||
if (utils::isa<VectorRef>(first)) {
|
|
||||||
auto old_vec_ref = utils::cast<VectorRef>(first);
|
|
||||||
if (utils::isa<VectorRef>(second)) {
|
|
||||||
auto new_vec_ref = utils::cast<VectorRef>(second);
|
|
||||||
std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref));
|
|
||||||
} else {
|
|
||||||
old_vec_ref.push_back(second);
|
|
||||||
}
|
|
||||||
return old_vec_ref;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (utils::isa<VectorRef>(second)) {
|
|
||||||
auto new_vec_ref = utils::cast<VectorRef>(second);
|
|
||||||
new_vec_ref.push_back(first);
|
|
||||||
return new_vec_ref;
|
|
||||||
}
|
|
||||||
|
|
||||||
return VectorRef({first, second});
|
|
||||||
}
|
|
||||||
|
|
||||||
void FinalVM::InstRealSwitch(const VectorRef &args) {
|
void FinalVM::InstRealSwitch(const VectorRef &args) {
|
||||||
const size_t args_size = 3;
|
const size_t args_size = 3;
|
||||||
if (args.size() != args_size) {
|
if (args.size() != args_size) {
|
||||||
|
@ -472,11 +319,7 @@ void FinalVM::InstRealSwitch(const VectorRef &args) {
|
||||||
|
|
||||||
void FinalVM::InstSwitch(const VectorRef &args) {
|
void FinalVM::InstSwitch(const VectorRef &args) {
|
||||||
MS_LOG(DEBUG) << "Start";
|
MS_LOG(DEBUG) << "Start";
|
||||||
if (backend_->is_multi_graph_sink()) {
|
InstRealSwitch(args);
|
||||||
InstSimuSwitch(args);
|
|
||||||
} else {
|
|
||||||
InstRealSwitch(args);
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "End";
|
MS_LOG(DEBUG) << "End";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -580,14 +423,6 @@ void FinalVM::InstExternal(const VectorRef &args) {
|
||||||
VectorRef tuple;
|
VectorRef tuple;
|
||||||
RunFunctionRef run_ref = utils::cast<RunFunctionRef>(args[0]);
|
RunFunctionRef run_ref = utils::cast<RunFunctionRef>(args[0]);
|
||||||
compile::RunFuncPtr fn = run_ref.func_;
|
compile::RunFuncPtr fn = run_ref.func_;
|
||||||
if (backend_->simu_flag()) {
|
|
||||||
MS_LOG(DEBUG) << "Simu run";
|
|
||||||
if (args.size() == 1) {
|
|
||||||
MS_LOG(EXCEPTION) << "The number of args should be greater than 1, but got 1";
|
|
||||||
}
|
|
||||||
auto simu_run_ref = utils::cast<RunFunctionRef>(args[1]);
|
|
||||||
fn = simu_run_ref.func_;
|
|
||||||
}
|
|
||||||
for (size_t i = 2; i < args.size(); ++i) {
|
for (size_t i = 2; i < args.size(); ++i) {
|
||||||
auto index = utils::cast<int>(args[i]);
|
auto index = utils::cast<int>(args[i]);
|
||||||
tuple.push_back(Ref(index));
|
tuple.push_back(Ref(index));
|
||||||
|
|
|
@ -96,7 +96,6 @@ class FinalVM {
|
||||||
public:
|
public:
|
||||||
// Create a VM with the specified instructions and backend.
|
// Create a VM with the specified instructions and backend.
|
||||||
explicit FinalVM(const InstSet &insts, const BackendPtr &backend);
|
explicit FinalVM(const InstSet &insts, const BackendPtr &backend);
|
||||||
|
|
||||||
virtual ~FinalVM() = default;
|
virtual ~FinalVM() = default;
|
||||||
|
|
||||||
BaseRef Eval(const VectorRef &args);
|
BaseRef Eval(const VectorRef &args);
|
||||||
|
@ -104,10 +103,8 @@ class FinalVM {
|
||||||
void InstTailCall(const VectorRef &args);
|
void InstTailCall(const VectorRef &args);
|
||||||
void InstReturn(const VectorRef &args);
|
void InstReturn(const VectorRef &args);
|
||||||
void InstPartial(const VectorRef &args);
|
void InstPartial(const VectorRef &args);
|
||||||
void InstSimuPartial(const VectorRef &args);
|
|
||||||
void InstRealPartial(const VectorRef &args);
|
void InstRealPartial(const VectorRef &args);
|
||||||
void InstSwitch(const VectorRef &args);
|
void InstSwitch(const VectorRef &args);
|
||||||
void InstSimuSwitch(const VectorRef &args);
|
|
||||||
void InstRealSwitch(const VectorRef &args);
|
void InstRealSwitch(const VectorRef &args);
|
||||||
void InstTuple(const VectorRef &args);
|
void InstTuple(const VectorRef &args);
|
||||||
void InstPush(const VectorRef &args);
|
void InstPush(const VectorRef &args);
|
||||||
|
@ -129,23 +126,16 @@ class FinalVM {
|
||||||
void Popp();
|
void Popp();
|
||||||
void Pushsp();
|
void Pushsp();
|
||||||
void Popsp();
|
void Popsp();
|
||||||
void PushStatus(bool is_switch_call);
|
|
||||||
bool PopStatus();
|
|
||||||
void DoJmp(const BaseRef &jmp);
|
void DoJmp(const BaseRef &jmp);
|
||||||
void SyncData(const py::object &args);
|
void SyncData(const py::object &args);
|
||||||
void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
|
|
||||||
BaseRef MergeArgs(const BaseRef &first, const BaseRef &second);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
InstSet insts_;
|
InstSet insts_;
|
||||||
std::deque<BaseRef> insts_stack_;
|
std::deque<BaseRef> insts_stack_;
|
||||||
std::stack<int> retp_;
|
std::stack<int> retp_;
|
||||||
std::stack<int> retsp_;
|
std::stack<int> retsp_;
|
||||||
std::stack<bool> ret_status_;
|
|
||||||
int pc_;
|
int pc_;
|
||||||
int sp_;
|
int sp_;
|
||||||
std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_jmp_;
|
|
||||||
std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_out_;
|
|
||||||
BackendPtr backend_;
|
BackendPtr backend_;
|
||||||
const InstFunctionMap inst_function_map = {
|
const InstFunctionMap inst_function_map = {
|
||||||
{Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},
|
{Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},
|
||||||
|
|
Loading…
Reference in New Issue