!18980 Add NewCNodeWithInfos for kernel graph

Merge pull request !18980 from LiangZhibo/master
This commit is contained in:
i-robot 2021-06-29 06:52:03 +00:00 committed by Gitee
commit cc8d760c8b
5 changed files with 17 additions and 2 deletions

View File

@ -93,6 +93,8 @@ CNodePtr NewRecomputeNode(const AnfNodePtr &orig_node, std::map<AnfNodePtr, AnfN
ScopePtr scope = (orig_node->scope() != kDefaultScope) ? orig_node->scope() : kDefaultScope;
cp_node->set_scope(scope);
cp_node->set_kernel_info(cnode->kernel_info_ptr());
cp_node->set_primal_attrs(cnode->primal_attrs());
cp_node->set_primal_debug_infos(cnode->primal_debug_infos());
(*node_map)[orig_node] = cp_node;
return cp_node->cast<CNodePtr>();
}

View File

@ -47,6 +47,8 @@ AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
ScopePtr scope = (anf_node->scope() != kDefaultScope) ? anf_node->scope() : kDefaultScope;
node->set_scope(scope);
node->set_kernel_info(cnode->kernel_info_ptr());
node->set_primal_attrs(cnode->primal_attrs());
node->set_primal_debug_infos(cnode->primal_debug_infos());
return node;
}

View File

@ -425,6 +425,16 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
return cnode;
}
CNodePtr KernelGraph::NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, const CNodePtr &ori_cnode) {
auto cnode = NewCNode(inputs);
if (ori_cnode != nullptr) {
cnode->set_attrs(ori_cnode->attrs());
cnode->set_primal_attrs(ori_cnode->primal_attrs());
cnode->set_primal_debug_infos(ori_cnode->primal_debug_infos());
}
return cnode;
}
void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
MS_EXCEPTION_IF_NULL(func_graph);

View File

@ -108,6 +108,7 @@ class KernelGraph : public FuncGraph {
void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter);
std::vector<AnfNodePtr> outputs() const;
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
CNodePtr NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, const CNodePtr &ori_cnode = nullptr);
void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
CNodePtr NewCNode(const CNodePtr &cnode);
void ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const;

View File

@ -687,7 +687,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
GetCNodeInfo(cnode, &cnode_inputs);
GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs);
auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
return new_cnode;
}
@ -997,7 +997,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
// handle inputs of cnode except primitive
CreateCNodeInputs(cnode, graph, &cnode_inputs);
TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs);
auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
// if the cnode is call switch, remove call
if (new_cnode->inputs().size() > 1) {
auto first_input = new_cnode->input(kFirstDataInputIndex);