forked from mindspore-Ecosystem/mindspore
!46525 [AutoParallel]Parallel save integral opt params
Merge pull request !46525 from lichen/parallel_save_integral_opt_params
This commit is contained in:
commit
e7a1806613
@@ -27,6 +27,7 @@
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/parameter_manager.h"

namespace mindspore {
namespace parallel {
@@ -233,10 +234,26 @@ py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
  std::vector<AnfNodePtr> graph_params = graph->parameters();

  for (auto para : graph_params) {
    std::string name = std::static_pointer_cast<Parameter>(para)->name();
    auto param_ptr = para->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    std::vector<std::string> names = {param_ptr->name()};
    auto param_info = param_ptr->param_info();
    if (param_info) {
      auto cloned_obj = GetPyParameterObj(param_info, CLONED_OBJ);
      if (!py::isinstance<py::none>(cloned_obj) && py::isinstance<py::list>(cloned_obj)) {
        auto obj_list = py::cast<py::list>(cloned_obj);
        for (size_t i = 0; i < obj_list.size(); ++i) {
          auto each_obj = obj_list[i];
          if (py::hasattr(each_obj, "name")) {
            auto name_obj = python_adapter::GetPyObjAttr(each_obj, "name");
            names.push_back(py::cast<std::string>(name_obj));
          }
        }
      }
    }
    auto tensor_layout = para->user_data<parallel::TensorLayout>();
    if (tensor_layout == nullptr) {
      MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
      MS_LOG(INFO) << "GetParameterLayout nullptr parameter: " << para->DebugString();
    } else {
      const auto &device_arrangement = tensor_layout->device_arrangement().array();
      const auto &tensor_map = tensor_layout->tensor_map().array();
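The hunk above collects, for each graph parameter, its own name plus the names of any cloned Python parameter objects (for example optimizer states registered under CLONED_OBJ). A minimal, self-contained sketch of that name-collection pattern in plain pybind11 follows; the MindSpore helpers (GetPyParameterObj, python_adapter) are replaced with stand-ins, and names such as adam_m.fc1.weight are purely illustrative assumptions, not values from this commit.

// Hypothetical sketch: mimic the cloned-object name collection with plain pybind11.
#include <pybind11/embed.h>
#include <iostream>
#include <string>
#include <vector>

namespace py = pybind11;
using namespace py::literals;

int main() {
  py::scoped_interpreter guard{};

  // Stand-in for the CLONED_OBJ list: Python objects carrying a `name` attribute.
  py::object ns = py::module_::import("types").attr("SimpleNamespace");
  py::list cloned_obj;
  cloned_obj.append(ns("name"_a = "adam_m.fc1.weight"));
  cloned_obj.append(ns("name"_a = "adam_v.fc1.weight"));

  std::vector<std::string> names = {"fc1.weight"};  // the parameter's own name
  for (size_t i = 0; i < cloned_obj.size(); ++i) {
    py::object each_obj = cloned_obj[i];
    if (py::hasattr(each_obj, "name")) {
      names.push_back(py::cast<std::string>(each_obj.attr("name")));
    }
  }

  // Prints the original name followed by both cloned names.
  for (const auto &n : names) {
    std::cout << n << std::endl;
  }
  return 0;
}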
@@ -246,8 +263,11 @@ py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
      const std::string &opt_shard_group = tensor_layout->opt_shard_group();
      py::tuple layout =
        py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);
      for (auto &name : names) {
        dict[py::str(name)] = layout;
        MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
      }
      MS_LOG(INFO) << "GetParameterLayout parameter: " << para->DebugString() << ", layout "
                   << tensor_layout->ToString();
    }
  }
  return dict;
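Every name collected this way is mapped to the same layout tuple in the returned dict, so the cloned optimizer parameters resolve to the sharding information of the parameter they were cloned from. The sketch below illustrates that registration step with plain pybind11; the shard values and parameter names are made-up assumptions for demonstration only.

// Hypothetical sketch: register one layout tuple under several parameter names.
#include <pybind11/embed.h>
#include <pybind11/stl.h>
#include <iostream>
#include <string>
#include <vector>

namespace py = pybind11;

int main() {
  py::scoped_interpreter guard{};

  // Stand-ins for the values read from TensorLayout in the real code.
  std::vector<int64_t> device_arrangement = {2, 4};
  std::vector<int64_t> tensor_map = {1, 0};
  std::vector<int64_t> slice_shape = {16, 32};
  int64_t field_size = 0;
  bool uniform_split = true;
  std::string opt_shard_group = "";

  py::tuple layout = py::make_tuple(device_arrangement, tensor_map, slice_shape,
                                    field_size, uniform_split, opt_shard_group);

  // One entry per name: the parameter itself plus its cloned optimizer states.
  std::vector<std::string> names = {"fc1.weight", "adam_m.fc1.weight", "adam_v.fc1.weight"};
  py::dict dict;
  for (const auto &name : names) {
    dict[py::str(name)] = layout;
  }

  std::cout << py::str(dict).cast<std::string>() << std::endl;
  return 0;
}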