From 463c74db0b3cf60daceb696a5cd553a147c9cc0c Mon Sep 17 00:00:00 2001
From: lichen
Date: Wed, 7 Dec 2022 14:05:45 +0800
Subject: [PATCH] parallel_save_integral_opt_params

---
 .../parallel/graph_util/get_parallel_info.cc | 28 ++++++++++++++++---
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
index adfc07451e5..7377585de6b 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
@@ -27,6 +27,7 @@
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/tensor_layout/tensor_layout.h"
 #include "frontend/parallel/ops_info/ops_utils.h"
+#include "frontend/parallel/parameter_manager.h"
 
 namespace mindspore {
 namespace parallel {
@@ -233,10 +234,26 @@ py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
   std::vector<AnfNodePtr> graph_params = graph->parameters();
 
   for (auto para : graph_params) {
-    std::string name = std::static_pointer_cast<Parameter>(para)->name();
+    auto param_ptr = para->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param_ptr);
+    std::vector<std::string> names = {param_ptr->name()};
+    auto param_info = param_ptr->param_info();
+    if (param_info) {
+      auto cloned_obj = GetPyParameterObj(param_info, CLONED_OBJ);
+      if (!py::isinstance<py::none>(cloned_obj) && py::isinstance<py::list>(cloned_obj)) {
+        auto obj_list = py::cast<py::list>(cloned_obj);
+        for (size_t i = 0; i < obj_list.size(); ++i) {
+          auto each_obj = obj_list[i];
+          if (py::hasattr(each_obj, "name")) {
+            auto name_obj = python_adapter::GetPyObjAttr(each_obj, "name");
+            names.push_back(py::cast<std::string>(name_obj));
+          }
+        }
+      }
+    }
     auto tensor_layout = para->user_data<TensorLayout>();
     if (tensor_layout == nullptr) {
-      MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
+      MS_LOG(INFO) << "GetParameterLayout nullptr parameter: " << para->DebugString();
     } else {
       const auto &device_arrangement = tensor_layout->device_arrangement().array();
       const auto &tensor_map = tensor_layout->tensor_map().array();
@@ -246,8 +263,11 @@ py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
       const std::string &opt_shard_group = tensor_layout->opt_shard_group();
       py::tuple layout =
         py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);
-      dict[py::str(name)] = layout;
-      MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
+      for (auto &name : names) {
+        dict[py::str(name)] = layout;
+      }
+      MS_LOG(INFO) << "GetParameterLayout parameter: " << para->DebugString() << ", layout "
+                   << tensor_layout->ToString();
     }
   }
   return dict;
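
Note for reviewers: the net effect of the diff is that GetParameterLayoutFromGraph now publishes one layout tuple under several dictionary keys, the parameter's own name plus the names of all of its cloned parameters (typically optimizer states) read from param_info's CLONED_OBJ list. The standalone sketch below is not MindSpore code; the parameter and clone names are invented for illustration, and a plain std::map of strings stands in for the py::dict of layout tuples. It only demonstrates the aliasing pattern in isolation:

    // layout_alias_sketch.cc -- illustrative only; all names are invented.
    #include <cassert>
    #include <map>
    #include <string>
    #include <vector>

    int main() {
      // A sharded parameter followed by its cloned optimizer state, mirroring
      // the `names` vector the patch builds from CLONED_OBJ.
      std::vector<std::string> names = {"fc.weight", "moments.fc.weight"};
      // Stand-in for the py::tuple layout (device_arrangement, tensor_map,
      // slice_shape, field_size, uniform_split, opt_shard_group).
      const std::string layout = "([4,2], [1,0], [16,8], 0, true, \"\")";

      // Register the same layout under every alias, as the patched loop does
      // with dict[py::str(name)] = layout;
      std::map<std::string, std::string> layout_dict;
      for (const auto &name : names) {
        layout_dict[name] = layout;
      }

      // A consumer can now resolve the slice layout of the cloned optimizer
      // parameter directly by the clone's own name.
      assert(layout_dict.at("moments.fc.weight") == layout_dict.at("fc.weight"));
      return 0;
    }

Duplicating the entry per alias, rather than keeping a separate name-to-name alias table, keeps each lookup on the consumer side a single dictionary access.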