fix the sub parameters name empty for tuple parameter expand

This commit is contained in:
limingqi107 2023-02-23 16:29:40 +08:00
parent 52b51d3ea8
commit bfb2e09444
2 changed files with 11 additions and 3 deletions

View File

@ -415,12 +415,16 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
auto abstract = parameter == nullptr ? std::make_shared<abstract::AbstractNone>() : parameter->abstract();
auto new_parameter = NewParameter(abstract);
MS_EXCEPTION_IF_NULL(new_parameter);
// if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
if (parameter != nullptr) {
new_parameter->set_name(parameter->name());
if (common::AnfAlgo::IsParameterWeight(parameter)) {
new_parameter->set_default_param(parameter->default_param());
}
} else {
// The created parameter name is empty, so set name to ensure that the parameter name is unique.
new_parameter->set_name(new_parameter->UniqueName());
}
// create kernel_info form new parameter
SetKernelInfoForNode(new_parameter);
@ -430,7 +434,10 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) {
ParameterPtr new_parameter = add_parameter();
MS_EXCEPTION_IF_NULL(new_parameter);
new_parameter->set_abstract(abstract);
// The created parameter name is empty, so set name to ensure that the parameter name is unique.
new_parameter->set_name(new_parameter->UniqueName());
// create kernel_info form new parameter
SetKernelInfoForNode(new_parameter);
AnfAlgo::SetGraphId(graph_id_, new_parameter.get());

View File

@ -197,9 +197,10 @@ void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(data_node.first);
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_node.first, data_node.second, false);
MS_EXCEPTION_IF_NULL(device_tensor);
ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node.first->DebugString()
<< "\tindex:" << data_node.second << "\tptr:" << device_tensor->GetPtr()
<< "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node.first->fullname_with_scope()
<< "\tdebug_name:" << data_node.first->DebugString() << "\tindex:" << data_node.second
<< "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
<< "\toriginal_ref_count:" << device_tensor->original_ref_count()
<< "\tdynamic_ref_count:" << device_tensor->dynamic_ref_count() << "\tflag:" << device_tensor->flag()
<< "\n ";
}