!47970 strategy_ckpt_viewable

Merge pull request !47970 from yao_yf/strategy_ckpt_viewable
i-robot 2023-02-02 09:34:19 +00:00 committed by Gitee
commit 73a445c03a
13 changed files with 429 additions and 135 deletions

View File

@ -55,12 +55,15 @@ if(ENABLE_D OR ENABLE_GPU)
endif()
if(NOT BUILD_LITE)
list(APPEND MSLIB_SRC_DEPEND
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.cc)
list(APPEND MSLIB_SRC_DEPEND
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc)
list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc)
list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc)
list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc)
list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc)
list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc)
list(APPEND MSLIB_SRC_DEPEND ${CMAKE_CURRENT_SOURCE_DIR}/utils.cc)
endif()

View File

@ -7,6 +7,7 @@ endif()
if(ENABLE_DUMP_PROTO)
list(REMOVE_ITEM _PARALLEL_SRC_FILES "parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
list(REMOVE_ITEM _PARALLEL_SRC_FILES "parallel/strategy_checkpoint/strategy_checkpoint_info.cc")
endif()
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")

View File

@ -2239,7 +2239,7 @@ static void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const F
}
tensor_info_map[cloned_param_name] = cloned_param_layout;
}
if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, manual_shape_map) != SUCCESS) {
MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
}
}

View File

@ -1268,7 +1268,7 @@ ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_
auto input = node_inputs[LongToSize(i)];
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
if (input_parameter->has_default()) {
(void)param_names.emplace_back(std::make_pair(input_parameter->name(), input_parameter));
}
} else if (input->isa<CNode>()) {

View File

@ -19,17 +19,16 @@
#include <fstream>
#include <vector>
#include <utility>
#include "include/common/utils/utils.h"
#include "utils/ms_utils.h"
#include "include/common/utils/convert_utils.h"
#include "utils/log_adapter.h"
#include "include/common/debug/common.h"
#include "proto/node_strategy.pb.h"
#include "mindspore/core/utils/file_utils.h"
namespace mindspore {
namespace parallel {
const uint32_t JSON_SUFFIX_LENGTH = 5;
StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
static StrategyCheckpoint instance = StrategyCheckpoint();
if (ParallelContext::GetInstance() != nullptr) {
@ -39,6 +38,10 @@ StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
instance.group_info_save_file_ = ParallelContext::GetInstance()->group_ckpt_save_file();
instance.group_info_save_on_ = !ParallelContext::GetInstance()->group_ckpt_save_file().empty();
instance.load_format_json_ = instance.load_file_.size() >= JSON_SUFFIX_LENGTH &&
instance.load_file_.substr(instance.load_file_.size() - JSON_SUFFIX_LENGTH) == ".json";
instance.save_format_json_ = instance.save_file_.size() >= JSON_SUFFIX_LENGTH &&
instance.save_file_.substr(instance.save_file_.size() - JSON_SUFFIX_LENGTH) == ".json";
}
return instance;
}
@ -110,96 +113,47 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
if (!CheckPointExit(load_file_)) {
MS_LOG(EXCEPTION) << "CheckPoint file is not found";
}
straspb::ParallelStrategyMap parallel_strategy_map;
std::fstream input(load_file_, std::ios::in | std::ios::binary);
if (!parallel_strategy_map.ParseFromIstream(&input)) {
MS_LOG(ERROR) << "Load strategy file failed";
return FAILED;
}
input.close();
size_t node_num = LongToSize(parallel_strategy_map.parallel_strategy_item_size());
for (size_t i = 0; i < node_num; i++) {
straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i));
std::string node_name = parallel_strategy_item.node_name();
straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
int64_t stage = SizeToLong(parallel_strategys.stage());
size_t strategys_num = LongToSize(parallel_strategys.parallel_strategy_size());
Strategies strategy_inputs;
for (size_t j = 0; j < strategys_num; j++) {
straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
Dimensions dimension;
size_t dim_num = LongToSize(parallel_strategy.dim_size());
for (size_t k = 0; k < dim_num; k++) {
dimension.push_back(parallel_strategy.dim(SizeToInt(k)));
}
strategy_inputs.push_back(dimension);
if (load_format_json_) {
std::fstream input(load_file_, std::ios::in);
nlohmann::json stra_ckpt_info_j;
input >> stra_ckpt_info_j;
strategy_checkpoint_info_.from_json(stra_ckpt_info_j);
} else {
straspb::ParallelStrategyMap parallel_strategy_map;
std::fstream input(load_file_, std::ios::in | std::ios::binary);
if (!parallel_strategy_map.ParseFromIstream(&input)) {
MS_LOG(ERROR) << "Load strategy file failed";
return FAILED;
}
StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
(*strategy_map)[node_name] = strategy;
current_stage_ = SizeToLong(parallel_strategy_map.current_stage());
input.close();
strategy_checkpoint_info_.from_protobuf(parallel_strategy_map);
}
*strategy_map = strategy_checkpoint_info_.strategy_map();
current_stage_ = SizeToLong(strategy_checkpoint_info_.current_stage());
return SUCCESS;
}
Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
ManualShapeMap *manual_shape_map) {
straspb::ParallelStrategyMap parallel_strategy_map;
parallel_strategy_map.set_current_stage(UlongToUint(LongToUlong(++current_stage_)));
for (auto &node_stra : strategy_map) {
straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
MS_EXCEPTION_IF_NULL(parallel_strategy_item);
parallel_strategy_item->set_node_name(node_stra.first);
straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
MS_EXCEPTION_IF_NULL(parallel_strategys);
MS_EXCEPTION_IF_NULL(node_stra.second);
parallel_strategys->set_stage(UlongToUint(LongToUlong(node_stra.second->GetInputStage())));
for (auto &dims : node_stra.second->GetInputDim()) {
straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
MS_EXCEPTION_IF_NULL(parallel_strategy);
for (auto stra_dim : dims) {
parallel_strategy->add_dim(UlongToUint(LongToUlong(stra_dim)));
}
}
}
for (auto &node_tensor_info : tensor_info_map) {
TensorLayoutPtr tensor_layout = node_tensor_info.second;
MS_EXCEPTION_IF_NULL(tensor_layout);
straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
MS_EXCEPTION_IF_NULL(parallel_layout_item);
parallel_layout_item->set_param_name(node_tensor_info.first);
straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
MS_EXCEPTION_IF_NULL(dev_matrix);
for (auto dev_dim : tensor_layout->device_arrangement().array()) {
dev_matrix->add_dim(UlongToUint(LongToUlong(dev_dim)));
}
straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
MS_EXCEPTION_IF_NULL(tensor_map);
for (auto map_dim : tensor_layout->tensor_map().array()) {
tensor_map->add_dim(LongToInt(map_dim));
}
straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();
MS_EXCEPTION_IF_NULL(manual_shape_map);
auto manual_shape = (*manual_shape_map)[node_tensor_info.first];
for (auto dim_pair : manual_shape) {
param_split_shape->add_dim(dim_pair.first);
indices_offset->add_dim(dim_pair.second);
}
parallel_layouts->set_field(LongToInt(tensor_layout->get_field_size()));
parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
}
const ManualShapeMap &manual_shape_map) {
if (!CheckPath(save_file_)) {
MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
}
std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
if (!parallel_strategy_map.SerializeToOstream(&output)) {
MS_LOG(ERROR) << "Save strategy file failed";
return FAILED;
strategy_checkpoint_info_.Init(strategy_map, tensor_info_map, manual_shape_map, ++current_stage_);
if (save_format_json_) {
auto stra_ckpt_info_j = strategy_checkpoint_info_.to_json();
std::fstream output(save_file_, std::ios::out);
output << stra_ckpt_info_j;
output.close();
} else {
auto parallel_strategy_map = strategy_checkpoint_info_.to_protobuf();
std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
if (!parallel_strategy_map.SerializeToOstream(&output)) {
MS_LOG(ERROR) << "Save strategy file failed";
return FAILED;
}
output.close();
}
output.close();
ChangeFileMode(save_file_, S_IRUSR | S_IWUSR);
return SUCCESS;
}
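
The Save/Load paths above choose the on-disk format from the file suffix: a name ending in ".json" (see JSON_SUFFIX_LENGTH) selects the new human-readable checkpoint, anything else keeps the protobuf format. A minimal usage sketch, assuming the standard mindspore.set_auto_parallel_context API that the updated tests at the end of this change also call; the file names are illustrative:

import mindspore as ms

# Saving with a ".json" suffix produces the viewable JSON strategy checkpoint;
# a ".ckpt" (or any non-".json") suffix keeps the original protobuf format.
ms.set_auto_parallel_context(device_num=8, global_rank=0,
                             strategy_ckpt_save_file="./strategy.json")
# ... define and compile the parallel network; the strategy file is written during compilation ...

# Loading applies the same suffix rule.
ms.set_auto_parallel_context(strategy_ckpt_load_file="./strategy.json")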

View File

@ -22,20 +22,14 @@
#include <memory>
#include <utility>
#include "utils/hash_map.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/strategy.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"
#include "frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.h"
namespace mindspore {
namespace parallel {
using StrategyMap = mindspore::HashMap<std::string, StrategyPtr>;
using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
using TensorInfoMap = mindspore::HashMap<std::string, TensorLayoutPtr>;
using ParameterMap = std::vector<std::pair<std::string, ParameterPtr>>;
using ManualShapeMap = mindspore::HashMap<std::string, std::vector<std::pair<int64_t, int64_t>>>;
using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
class StrategyCheckpoint {
public:
StrategyCheckpoint() {
@ -47,7 +41,8 @@ class StrategyCheckpoint {
Status Load(StrategyMap *strategy_map);
Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) const;
Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map);
Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
const ManualShapeMap &manual_shape_map);
Status SaveGroupInfo(const GroupInfoMap &group_info_map, const RankList &restore_rank_list);
bool group_info_save_on() const { return group_info_save_on_; }
@ -65,6 +60,9 @@ class StrategyCheckpoint {
int64_t current_stage_ = 0;
std::string group_info_save_file_;
bool group_info_save_on_ = false;
bool load_format_json_ = true;
bool save_format_json_ = true;
StrategyCheckpointInfo strategy_checkpoint_info_;
};
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,179 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.h"
#include <fstream>
#include <vector>
#include <utility>
namespace mindspore {
namespace parallel {
void StrategyCheckpointInfo::set_strategy_map(const StrategyMap &strategy_map) { strategy_map_ = strategy_map; }
void StrategyCheckpointInfo::set_tensor_info_map(const TensorInfoMap &tensor_info_map) {
tensor_info_map_ = tensor_info_map;
}
void StrategyCheckpointInfo::set_manual_shape_map(const ManualShapeMap &manual_shape_map) {
manual_shape_map_ = manual_shape_map;
}
void StrategyCheckpointInfo::from_json(const nlohmann::json &stra_ckpt_info_j) {
current_stage_ = stra_ckpt_info_j.at("current_stage").get<int64_t>();
for (const auto &stra_j : stra_ckpt_info_j.at("parallel_strategy_item").items()) {
auto node_name = stra_j.key();
auto stage = stra_j.value().at("stage").get<int64_t>();
auto stra = stra_j.value().at("parallel_strategy").get<std::vector<std::vector<int64_t>>>();
strategy_map_[node_name] = std::make_shared<Strategy>(stage, stra);
}
for (const auto &layout_j : stra_ckpt_info_j.at("parallel_layout_item").items()) {
auto parameter_name = layout_j.key();
auto dev_matrix = layout_j.value().at("dev_matrix").get<std::vector<int64_t>>();
auto tensor_map = layout_j.value().at("tensor_map").get<std::vector<int64_t>>();
auto tensor_shape = layout_j.value().at("tensor_shape").get<std::vector<int64_t>>();
auto field = layout_j.value().at("field").get<int64_t>();
auto opt_weight_shard_step = layout_j.value().at("opt_weight_shard_step").get<int64_t>();
auto opt_weight_shard_size = layout_j.value().at("opt_weight_shard_size").get<int64_t>();
if (layout_j.value().contains("param_split_shape") && layout_j.value().contains("indices_offset")) {
auto param_split_shape = layout_j.value().at("param_split_shape").get<std::vector<int64_t>>();
auto indices_offset = layout_j.value().at("indices_offset").get<std::vector<int64_t>>();
if (param_split_shape.size() != indices_offset.size()) {
MS_LOG(EXCEPTION) << "For field_split strategy, the size of param_split_shape " << param_split_shape.size()
<< " is not equal to the size of indices_offset " << indices_offset.size();
}
for (size_t i = 0; i < param_split_shape.size(); ++i) {
manual_shape_map_[parameter_name].push_back({param_split_shape[i], indices_offset[i]});
}
}
tensor_info_map_[parameter_name] = std::make_shared<TensorLayout>();
tensor_info_map_[parameter_name]->InitFromVector(dev_matrix, tensor_map, tensor_shape);
tensor_info_map_[parameter_name]->set_opt_weight_shard_size(opt_weight_shard_size);
tensor_info_map_[parameter_name]->set_opt_weight_shard_step(opt_weight_shard_step);
tensor_info_map_[parameter_name]->set_field_size(field);
}
}
nlohmann::json StrategyCheckpointInfo::to_json() const {
nlohmann::json stra_ckpt_info_j;
stra_ckpt_info_j["current_stage"] = current_stage_;
for (const auto &stra_pair : strategy_map_) {
auto node_name = stra_pair.first;
auto node_stra = stra_pair.second;
nlohmann::json stra_j;
stra_j["stage"] = node_stra->GetInputStage();
stra_j["parallel_strategy"] = node_stra->GetInputDim();
stra_ckpt_info_j["parallel_strategy_item"][node_name] = stra_j;
}
for (const auto &layout_pair : tensor_info_map_) {
auto parameter_name = layout_pair.first;
auto layout = layout_pair.second;
nlohmann::json layout_j;
layout_j["dev_matrix"] = layout->device_arrangement().array();
layout_j["tensor_map"] = layout->tensor_map().array();
layout_j["tensor_shape"] = layout->tensor_shape().array();
layout_j["field"] = layout->get_field_size();
layout_j["opt_weight_shard_step"] = layout->opt_weight_shard_step();
layout_j["opt_weight_shard_size"] = layout->opt_weight_shard_size();
if (manual_shape_map_.find(parameter_name) != manual_shape_map_.end()) {
auto manual_shape = manual_shape_map_.at(parameter_name);
for (auto dim_pair : manual_shape) {
layout_j["param_split_shape"].push_back(dim_pair.first);
layout_j["indices_offset"].push_back(dim_pair.second);
}
}
stra_ckpt_info_j["parallel_layout_item"][parameter_name] = layout_j;
}
return stra_ckpt_info_j;
}
void StrategyCheckpointInfo::from_protobuf(const straspb::ParallelStrategyMap &parallel_strategy_map) {
size_t node_num = LongToSize(parallel_strategy_map.parallel_strategy_item_size());
for (size_t i = 0; i < node_num; i++) {
straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i));
std::string node_name = parallel_strategy_item.node_name();
straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
int64_t stage = SizeToLong(parallel_strategys.stage());
size_t strategys_num = LongToSize(parallel_strategys.parallel_strategy_size());
Strategies strategy_inputs;
for (size_t j = 0; j < strategys_num; j++) {
straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
Dimensions dimension;
size_t dim_num = LongToSize(parallel_strategy.dim_size());
for (size_t k = 0; k < dim_num; k++) {
dimension.push_back(parallel_strategy.dim(SizeToInt(k)));
}
strategy_inputs.push_back(dimension);
}
StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
strategy_map_[node_name] = strategy;
current_stage_ = SizeToLong(parallel_strategy_map.current_stage());
}
}
straspb::ParallelStrategyMap StrategyCheckpointInfo::to_protobuf() const {
straspb::ParallelStrategyMap parallel_strategy_map;
parallel_strategy_map.set_current_stage(UlongToUint(LongToUlong(current_stage_)));
for (auto &node_stra : strategy_map_) {
straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
MS_EXCEPTION_IF_NULL(parallel_strategy_item);
parallel_strategy_item->set_node_name(node_stra.first);
straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
MS_EXCEPTION_IF_NULL(parallel_strategys);
MS_EXCEPTION_IF_NULL(node_stra.second);
parallel_strategys->set_stage(UlongToUint(LongToUlong(node_stra.second->GetInputStage())));
for (auto &dims : node_stra.second->GetInputDim()) {
straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
MS_EXCEPTION_IF_NULL(parallel_strategy);
for (auto stra_dim : dims) {
parallel_strategy->add_dim(UlongToUint(LongToUlong(stra_dim)));
}
}
}
for (auto &node_tensor_info : tensor_info_map_) {
TensorLayoutPtr tensor_layout = node_tensor_info.second;
MS_EXCEPTION_IF_NULL(tensor_layout);
straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
MS_EXCEPTION_IF_NULL(parallel_layout_item);
parallel_layout_item->set_param_name(node_tensor_info.first);
straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
MS_EXCEPTION_IF_NULL(dev_matrix);
for (auto dev_dim : tensor_layout->device_arrangement().array()) {
dev_matrix->add_dim(UlongToUint(LongToUlong(dev_dim)));
}
straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
MS_EXCEPTION_IF_NULL(tensor_map);
for (auto map_dim : tensor_layout->tensor_map().array()) {
tensor_map->add_dim(LongToInt(map_dim));
}
straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();
parallel_layouts->set_field(LongToInt(tensor_layout->get_field_size()));
parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
if (manual_shape_map_.find(node_tensor_info.first) != manual_shape_map_.end()) {
auto manual_shape = manual_shape_map_.at(node_tensor_info.first);
for (auto dim_pair : manual_shape) {
param_split_shape->add_dim(dim_pair.first);
indices_offset->add_dim(dim_pair.second);
}
}
}
return parallel_strategy_map;
}
} // namespace parallel
} // namespace mindspore
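
For reference, a sketch of inspecting the JSON produced by to_json() above; the key names (current_stage, parallel_strategy_item, parallel_layout_item and the per-layout fields) come from the code, while the file name is illustrative:

import json

# Walk a strategy checkpoint that was saved with a ".json" suffix.
with open("./strategy.json", "r") as f:
    stra_ckpt = json.load(f)

print(stra_ckpt["current_stage"])
for node_name, item in stra_ckpt["parallel_strategy_item"].items():
    print(node_name, item["stage"], item["parallel_strategy"])
for param_name, layout in stra_ckpt["parallel_layout_item"].items():
    # "param_split_shape"/"indices_offset" are present only for field-sliced parameters.
    print(param_name, layout["dev_matrix"], layout["tensor_map"], layout["tensor_shape"],
          layout["field"], layout["opt_weight_shard_step"], layout["opt_weight_shard_size"])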

View File

@ -0,0 +1,74 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHECKPOINT_STRATEGY_CHECKPOINT_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHECKPOINT_STRATEGY_CHECKPOINT_INFO_H_
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include "nlohmann/json.hpp"
#include "utils/hash_map.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"
#include "proto/node_strategy.pb.h"
namespace mindspore {
namespace parallel {
using StrategyMap = mindspore::HashMap<std::string, StrategyPtr>;
using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
using TensorInfoMap = mindspore::HashMap<std::string, TensorLayoutPtr>;
using ParameterMap = std::vector<std::pair<std::string, ParameterPtr>>;
using ManualShapeMap = mindspore::HashMap<std::string, std::vector<std::pair<int64_t, int64_t>>>;
using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
class StrategyCheckpointInfo {
public:
StrategyCheckpointInfo() {}
~StrategyCheckpointInfo() = default;
void Init(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
const ManualShapeMap &manual_shape_map, int64_t current_stage) {
strategy_map_ = strategy_map;
tensor_info_map_ = tensor_info_map;
manual_shape_map_ = manual_shape_map;
current_stage_ = current_stage;
}
StrategyMap strategy_map() const { return strategy_map_; }
void set_strategy_map(const StrategyMap &strategy_map);
TensorInfoMap tensor_info_map() const { return tensor_info_map_; }
void set_tensor_info_map(const TensorInfoMap &tensor_info_map);
ManualShapeMap manual_shape_map() const { return manual_shape_map_; }
void set_manual_shape_map(const ManualShapeMap &manual_shape_map);
int64_t current_stage() const { return current_stage_; }
void from_json(const nlohmann::json &stra_ckpt_info_j);
nlohmann::json to_json() const;
void from_protobuf(const straspb::ParallelStrategyMap &parallel_strategy_map);
straspb::ParallelStrategyMap to_protobuf() const;
private:
int64_t current_stage_;
StrategyMap strategy_map_;
TensorInfoMap tensor_info_map_;
ManualShapeMap manual_shape_map_;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHECKPOINT_STRATEGY_CHECKPOINT_INFO_H_

View File

@ -61,11 +61,13 @@ set(MSLIB_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
if(ENABLE_D)
list(APPEND MSLIB_INFER_SRC
"${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.cc"
"${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc"
"${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc"
"${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc"
"${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc"
"${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc")
"${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc"
"${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc")
endif()
add_library(mindspore_infer_shared_lib_obj OBJECT ${MSLIB_INFER_SRC})

View File

@ -16,6 +16,7 @@
from __future__ import absolute_import
import os
import json
import numpy as np
import mindspore as ms
from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
@ -81,7 +82,7 @@ def _convert_to_layout(param_name, tensor_layout):
return strategy
def _load_strategy_file(strategy_filename):
def _check_strategy_file(strategy_filename):
"""load parallel strategy file"""
if not isinstance(strategy_filename, str):
raise TypeError(f"For 'build_searched_strategy', the argument 'strategy_filename' should be string, "
@ -94,18 +95,25 @@ def _load_strategy_file(strategy_filename):
if os.path.getsize(strategy_filename) == 0:
raise ValueError(f"For 'build_searched_strategy', the strategy file {strategy_filename} should not "
f"be empty. Please check whether the 'strategy_filename' is correct.")
parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
def _load_protobuf_strategy(strategy_filename):
"""load strategy from protobuf file"""
parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
with open(strategy_filename, 'rb') as f:
pb_content = f.read()
parallel_strategy_map.ParseFromString(pb_content)
try:
parallel_strategy_map.ParseFromString(pb_content)
except BaseException as e:
raise TypeError("The strategy file type should be one of json or protobuf. "
"When the file name extension is not '.json', "
"the file is considered as a protobuf file.") from e
return parallel_strategy_map
def _build_searched_strategy(strategy_filename):
"""build searched strategy"""
parallel_strategy_map = _load_strategy_file(strategy_filename)
def _build_protobuf_strategy(strategy_filename):
"""build strategy from protobuf file"""
parallel_strategy_map = _load_protobuf_strategy(strategy_filename)
layout_items = parallel_strategy_map.parallel_layout_item
if not layout_items:
raise ValueError(f"For 'build_searched_strategy', the strategy file {strategy_filename} has no sliced "
@ -116,10 +124,94 @@ def _build_searched_strategy(strategy_filename):
parameter_name = layout_item.param_name
layout = layout_item.parallel_layouts
strategy[parameter_name] = layout
return strategy
def _build_json_strategy(strategy_filename):
"""build strategy from json file"""
with open(strategy_filename, 'r') as f:
json_content = json.load(f)
layout_items = json_content.get("parallel_layout_item")
strategy = {}
for parameter_name, layout_item in layout_items.items():
layout = ms.train.node_strategy_pb2.ParallelLayouts()
layout.field = layout_item.get("field")
layout.opt_weight_shard_size = layout_item.get("opt_weight_shard_size")
layout.opt_weight_shard_step = layout_item.get("opt_weight_shard_step")
dev_matrix = layout.dev_matrix.add()
for item in layout_item.get("dev_matrix"):
dev_matrix.dim.append(item)
tensor_map = layout.tensor_map.add()
for item in layout_item.get("tensor_map"):
tensor_map.dim.append(item)
param_split_shape = layout.param_split_shape.add()
if "param_split_shape" in layout_item:
for item in layout_item.get("param_split_shape"):
param_split_shape.dim.append(item)
indices_offset = layout.indices_offset.add()
if "indices_offset" in layout_item:
for item in layout_item.get("indices_offset"):
indices_offset.dim.append(item)
strategy[parameter_name] = layout
return strategy
def _build_searched_strategy(strategy_filename):
"""build searched strategy"""
_check_strategy_file(strategy_filename)
if strategy_filename[-5:] != ".json":
return _build_protobuf_strategy(strategy_filename)
return _build_json_strategy(strategy_filename)
def _merge_protobuf_strategy(src_strategy_files, dst_strategy_file):
"""merge protobuf strategy"""
dst_parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
merged_stage = []
for src_strategy_file in src_strategy_files:
src_parallel_strategy_map = _load_protobuf_strategy(src_strategy_file)
strategy_items = src_parallel_strategy_map.parallel_strategy_item
layout_items = src_parallel_strategy_map.parallel_layout_item
if not strategy_items or not layout_items:
raise ValueError("The strategy file {} is empty".format(src_strategy_file))
pipeline_stage = strategy_items[0].parallel_strategys.stage
if pipeline_stage in merged_stage:
continue
for layout_item in layout_items:
layout_item.param_name = "-".join([str(pipeline_stage), layout_item.param_name])
dst_parallel_strategy_map.parallel_strategy_item.extend(strategy_items)
dst_parallel_strategy_map.parallel_layout_item.extend(layout_items)
merged_stage.append(pipeline_stage)
dst_parallel_strategy_map.current_stage = 1
with open(dst_strategy_file, "wb") as f:
f.write(dst_parallel_strategy_map.SerializeToString())
def _merge_json_strategy(src_strategy_files, dst_strategy_file):
"""merge protobuf strategy"""
dst_parallel_strategy_map = {"current_stage": 1, "parallel_strategy_item": {}, "parallel_layout_item": {}}
merged_stage = []
for src_strategy_file in src_strategy_files:
with open(src_strategy_file, 'r') as f:
json_content = json.load(f)
layout_items = json_content.get("parallel_layout_item")
strategy_items = json_content.get("parallel_strategy_item")
if not strategy_items or not layout_items:
raise ValueError("The strategy file {} is empty".format(src_strategy_file))
pipeline_stage = strategy_items.get(list(strategy_items.keys())[0]).get('stage')
if pipeline_stage in merged_stage:
continue
for param_name, layout_item in layout_items.items():
new_layout_item = {}
new_param_name = "-".join([str(pipeline_stage), param_name])
new_layout_item[new_param_name] = layout_item
dst_parallel_strategy_map.get("parallel_layout_item").update(new_layout_item)
dst_parallel_strategy_map.get("parallel_strategy_item").update(strategy_items)
merged_stage.append(pipeline_stage)
with open(dst_strategy_file, "w") as f:
json.dump(dst_parallel_strategy_map, f)
def _parameter_not_in_local_stage(param_name, origin_strategy_list, strategy_list):
"""parameter whether in the local stage"""
if origin_strategy_list is None or strategy_list is None:

View File

@ -22,8 +22,9 @@ from collections import defaultdict
import numpy as np
import mindspore as ms
from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
_transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, _load_strategy_file, \
_extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num
_transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
_extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
_merge_protobuf_strategy, _merge_json_strategy
__all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank",
@ -55,26 +56,16 @@ def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
_make_dir(dst_strategy_dir, "path")
if not os.path.isdir(src_strategy_dirs):
raise NotADirectoryError("src_strategy_dirs {} is not a directory.".format(src_strategy_dirs))
src_strategy_files = os.path.join(src_strategy_dirs, "*.ckpt")
dst_parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
merged_stage = []
for src_strategy_file in glob.glob(src_strategy_files):
src_parallel_strategy_map = _load_strategy_file(src_strategy_file)
strategy_items = src_parallel_strategy_map.parallel_strategy_item
layout_items = src_parallel_strategy_map.parallel_layout_item
if not strategy_items or not layout_items:
raise ValueError("The strategy file {} is empty".format(src_strategy_file))
pipeline_stage = strategy_items[0].parallel_strategys.stage
if pipeline_stage in merged_stage:
continue
for layout_item in layout_items:
layout_item.param_name = "-".join([str(pipeline_stage), layout_item.param_name])
dst_parallel_strategy_map.parallel_strategy_item.extend(strategy_items)
dst_parallel_strategy_map.parallel_layout_item.extend(layout_items)
merged_stage.append(pipeline_stage)
dst_parallel_strategy_map.current_stage = 1
with open(dst_strategy_file, "wb") as f:
f.write(dst_parallel_strategy_map.SerializeToString())
src_strategy_files_protobuf = glob.glob(os.path.join(src_strategy_dirs, "*.ckpt"))
src_strategy_files_json = glob.glob(os.path.join(src_strategy_dirs, "*.json"))
if src_strategy_files_protobuf and src_strategy_files_json:
raise ValueError("The strategys format should be all '.ckpt' or all '.json'")
is_protobuf = len(src_strategy_files_protobuf) > 0
if is_protobuf:
_merge_protobuf_strategy(src_strategy_files_protobuf, dst_strategy_file)
else:
_merge_json_strategy(src_strategy_files_json, dst_strategy_file)
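
A usage sketch for the updated merge_pipeline_strategys above: the source directory must hold per-stage strategy files that are either all ".ckpt" (protobuf) or all ".json"; the import path is an assumption and the paths are illustrative.

from mindspore.parallel import merge_pipeline_strategys  # assumed import location

# Merge per-stage strategy files from a pipeline-parallel run into one file;
# ".json" inputs go through _merge_json_strategy, ".ckpt" inputs through
# _merge_protobuf_strategy.
merge_pipeline_strategys("./stage_strategies", "./merged_strategy.json")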
def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None):
@ -106,7 +97,7 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=N
>>> rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
>>> checkpoint_files_map = {}
>>> for rank in rank_list:
>>> checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)
... checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)
"""
if not isinstance(rank_id, int):
@ -167,14 +158,14 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
Examples:
>>> dst_device_num = 8
>>> for rank_id in range(dst_device_num)
>>> rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
>>> checkpoint_files_map = {}
>>> for rank in rank_list:
>>> checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank)
>>> save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id)
>>> transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
>>> "./src_strategy.ckpt", "./dst_strategy.ckpt")
>>> for rank_id in range(dst_device_num):
... rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
... checkpoint_files_map = {}
... for rank in rank_list:
... checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank)
... save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id)
... transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
... "./src_strategy.ckpt", "./dst_strategy.ckpt")
"""
if not isinstance(checkpoint_files_map, dict):

View File

@ -30,7 +30,7 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return f
Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; }
Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
ManualShapeMap *manual_shape_map) { return SUCCESS; }
const ManualShapeMap &manual_shape_map) { return SUCCESS; }
Status StrategyCheckpoint::LoadGroupInfo(const std::string &file,
GroupInfoMap *group_info_map) const { return SUCCESS; }

View File

@ -201,7 +201,7 @@ def test_six_matmul_save_auto():
return out
reset_auto_parallel_context()
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.json")
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel", dataset_strategy="full_batch")
x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
@ -256,7 +256,7 @@ def six_matmul_load_auto():
return out
reset_auto_parallel_context()
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.ckpt")
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.json")
strategy1 = ((2, 2), (2, 2))
strategy3 = ((2, 2), (2, 2))
strategy4 = ((2, 2), (2, 2))