forked from mindspore-Ecosystem/mindspore
!18980 Add NewCNodeWithInfos for kernel graph
Merge pull request !18980 from LiangZhibo/master
This commit is contained in:
commit
cc8d760c8b
|
@ -93,6 +93,8 @@ CNodePtr NewRecomputeNode(const AnfNodePtr &orig_node, std::map<AnfNodePtr, AnfN
|
|||
ScopePtr scope = (orig_node->scope() != kDefaultScope) ? orig_node->scope() : kDefaultScope;
|
||||
cp_node->set_scope(scope);
|
||||
cp_node->set_kernel_info(cnode->kernel_info_ptr());
|
||||
cp_node->set_primal_attrs(cnode->primal_attrs());
|
||||
cp_node->set_primal_debug_infos(cnode->primal_debug_infos());
|
||||
(*node_map)[orig_node] = cp_node;
|
||||
return cp_node->cast<CNodePtr>();
|
||||
}
|
||||
|
|
|
@ -47,6 +47,8 @@ AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
|
|||
ScopePtr scope = (anf_node->scope() != kDefaultScope) ? anf_node->scope() : kDefaultScope;
|
||||
node->set_scope(scope);
|
||||
node->set_kernel_info(cnode->kernel_info_ptr());
|
||||
node->set_primal_attrs(cnode->primal_attrs());
|
||||
node->set_primal_debug_infos(cnode->primal_debug_infos());
|
||||
return node;
|
||||
}
|
||||
|
||||
|
|
|
@ -425,6 +425,16 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
|||
return cnode;
|
||||
}
|
||||
|
||||
CNodePtr KernelGraph::NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, const CNodePtr &ori_cnode) {
|
||||
auto cnode = NewCNode(inputs);
|
||||
if (ori_cnode != nullptr) {
|
||||
cnode->set_attrs(ori_cnode->attrs());
|
||||
cnode->set_primal_attrs(ori_cnode->primal_attrs());
|
||||
cnode->set_primal_debug_infos(ori_cnode->primal_debug_infos());
|
||||
}
|
||||
return cnode;
|
||||
}
|
||||
|
||||
void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
|
||||
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
|
|
@ -108,6 +108,7 @@ class KernelGraph : public FuncGraph {
|
|||
void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter);
|
||||
std::vector<AnfNodePtr> outputs() const;
|
||||
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
|
||||
CNodePtr NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, const CNodePtr &ori_cnode = nullptr);
|
||||
void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
|
||||
CNodePtr NewCNode(const CNodePtr &cnode);
|
||||
void ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const;
|
||||
|
|
|
@ -687,7 +687,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
|
|||
GetCNodeInfo(cnode, &cnode_inputs);
|
||||
GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
|
||||
TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
|
||||
auto new_cnode = graph->NewCNode(cnode_inputs);
|
||||
auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
|
||||
return new_cnode;
|
||||
}
|
||||
|
||||
|
@ -997,7 +997,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
|
|||
// handle inputs of cnode except primitive
|
||||
CreateCNodeInputs(cnode, graph, &cnode_inputs);
|
||||
TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
|
||||
auto new_cnode = graph->NewCNode(cnode_inputs);
|
||||
auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
|
||||
// if the cnode is call switch, remove call
|
||||
if (new_cnode->inputs().size() > 1) {
|
||||
auto first_input = new_cnode->input(kFirstDataInputIndex);
|
||||
|
|
Loading…
Reference in New Issue