inherit for dynamic param kernel.

This commit is contained in:
ZPaC 2023-02-03 17:45:45 +08:00
parent 8d46d9072b
commit 36d2cb0d88
1 changed files with 10 additions and 3 deletions

View File

@ -105,13 +105,20 @@ AnfNodePtr CreateNewNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &i
if (kernel::IsDynamicParamKernel(origin_prim->name())) {
SetKernelInfoForDynamicParamKernel(new_cnode);
} else if (IsPrimitiveEquals(new_prim, origin_prim)) {
// If the primitive is not changed, inherit input and output from origin node.
SetKernelInfoForNewCNode(new_cnode, false);
// Reset output object type.
new_kernel_builder->SetOutputsKernelObjectType(origin_kernel_build_info->GetAllOutputKernelObjectTypes());
} else {
SetKernelInfoForNewCNode(new_cnode, true);
}
// If the primitive is not changed, this means only inputs are updated. So inherit output from origin node.
if (IsPrimitiveEquals(new_prim, origin_prim)) {
KernelBuildInfoPtr new_node_build_info = AnfAlgo::GetSelectKernelBuildInfo(new_cnode);
KernelBuildInfoPtr origin_node_build_info = AnfAlgo::GetSelectKernelBuildInfo(origin_node);
new_node_build_info->SetOutputsFormat(origin_node_build_info->GetAllOutputFormats());
new_node_build_info->SetOutputsDeviceType(origin_node_build_info->GetAllOutputDeviceTypes());
new_node_build_info->SetOutputsKernelObjectType(origin_node_build_info->GetAllOutputKernelObjectTypes());
}
return new_cnode;
}