!46525 [AutoParallel]Parallel save integral opt params

Merge pull request !46525 from lichen/parallel_save_integral_opt_params
This commit is contained in:
i-robot 2022-12-08 01:48:27 +00:00 committed by Gitee
commit e7a1806613
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 24 additions and 4 deletions

View File

@ -27,6 +27,7 @@
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/parameter_manager.h"
namespace mindspore {
namespace parallel {
@ -233,10 +234,26 @@ py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
std::vector<AnfNodePtr> graph_params = graph->parameters();
for (auto para : graph_params) {
std::string name = std::static_pointer_cast<Parameter>(para)->name();
auto param_ptr = para->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
std::vector<std::string> names = {param_ptr->name()};
auto param_info = param_ptr->param_info();
if (param_info) {
auto cloned_obj = GetPyParameterObj(param_info, CLONED_OBJ);
if (!py::isinstance<py::none>(cloned_obj) && py::isinstance<py::list>(cloned_obj)) {
auto obj_list = py::cast<py::list>(cloned_obj);
for (size_t i = 0; i < obj_list.size(); ++i) {
auto each_obj = obj_list[i];
if (py::hasattr(each_obj, "name")) {
auto name_obj = python_adapter::GetPyObjAttr(each_obj, "name");
names.push_back(py::cast<std::string>(name_obj));
}
}
}
}
auto tensor_layout = para->user_data<parallel::TensorLayout>();
if (tensor_layout == nullptr) {
MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
MS_LOG(INFO) << "GetParameterLayout nullptr parameter: " << para->DebugString();
} else {
const auto &device_arrangement = tensor_layout->device_arrangement().array();
const auto &tensor_map = tensor_layout->tensor_map().array();
@ -246,8 +263,11 @@ py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
const std::string &opt_shard_group = tensor_layout->opt_shard_group();
py::tuple layout =
py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);
dict[py::str(name)] = layout;
MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
for (auto &name : names) {
dict[py::str(name)] = layout;
}
MS_LOG(INFO) << "GetParameterLayout parameter: " << para->DebugString() << ", layout "
<< tensor_layout->ToString();
}
}
return dict;