forked from mindspore-Ecosystem/mindspore
inherit for dynamic param kernel.
This commit is contained in:
parent
8d46d9072b
commit
36d2cb0d88
|
@ -105,13 +105,20 @@ AnfNodePtr CreateNewNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &i
|
|||
if (kernel::IsDynamicParamKernel(origin_prim->name())) {
|
||||
SetKernelInfoForDynamicParamKernel(new_cnode);
|
||||
} else if (IsPrimitiveEquals(new_prim, origin_prim)) {
|
||||
// If the primitive is not changed, inherit input and output from origin node.
|
||||
SetKernelInfoForNewCNode(new_cnode, false);
|
||||
// Reset output object type.
|
||||
new_kernel_builder->SetOutputsKernelObjectType(origin_kernel_build_info->GetAllOutputKernelObjectTypes());
|
||||
} else {
|
||||
SetKernelInfoForNewCNode(new_cnode, true);
|
||||
}
|
||||
|
||||
// If the primitive is not changed, this means only inputs are updated. So inherit output from origin node.
|
||||
if (IsPrimitiveEquals(new_prim, origin_prim)) {
|
||||
KernelBuildInfoPtr new_node_build_info = AnfAlgo::GetSelectKernelBuildInfo(new_cnode);
|
||||
KernelBuildInfoPtr origin_node_build_info = AnfAlgo::GetSelectKernelBuildInfo(origin_node);
|
||||
new_node_build_info->SetOutputsFormat(origin_node_build_info->GetAllOutputFormats());
|
||||
new_node_build_info->SetOutputsDeviceType(origin_node_build_info->GetAllOutputDeviceTypes());
|
||||
new_node_build_info->SetOutputsKernelObjectType(origin_node_build_info->GetAllOutputKernelObjectTypes());
|
||||
}
|
||||
|
||||
return new_cnode;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue