forked from mindspore-Ecosystem/mindspore
code check of strategy checkpoint
This commit is contained in:
parent
3f7e42e80c
commit
61d07f98a8
|
@ -22,6 +22,7 @@
|
|||
#include "utils/ms_utils.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "debug/common.h"
|
||||
#include "proto/node_strategy.pb.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -39,6 +40,19 @@ StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
|
|||
return instance;
|
||||
}
|
||||
|
||||
bool StrategyCheckpoint::CheckPath(const std::string path) const {
|
||||
if (path.size() > PATH_MAX) {
|
||||
MS_LOG(ERROR) << "The checkpoit path " << path << " is too long";
|
||||
return false;
|
||||
}
|
||||
auto realpath = Common::GetRealPath(path);
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(ERROR) << "Get real path failed, path=" << path;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
|
||||
std::ifstream fin(path);
|
||||
if (fin) {
|
||||
|
@ -49,6 +63,9 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
|
|||
|
||||
Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) {
|
||||
MS_EXCEPTION_IF_NULL(group_info_map);
|
||||
if (!CheckPath(file)) {
|
||||
MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
|
||||
}
|
||||
if (!CheckPointExit(file)) {
|
||||
MS_LOG(EXCEPTION) << "CheckPoint file is not found";
|
||||
}
|
||||
|
@ -84,6 +101,9 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
|
|||
if (strategy_map == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
|
||||
}
|
||||
if (!CheckPath(load_file_)) {
|
||||
MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
|
||||
}
|
||||
if (!CheckPointExit(load_file_)) {
|
||||
MS_LOG(EXCEPTION) << "CheckPoint file is not found";
|
||||
}
|
||||
|
@ -167,7 +187,9 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
|
|||
parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
|
||||
parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
|
||||
}
|
||||
|
||||
if (!CheckPath(save_file_)) {
|
||||
MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
|
||||
}
|
||||
std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
if (!parallel_strategy_map.SerializeToOstream(&output)) {
|
||||
MS_LOG(ERROR) << "Save strategy file failed";
|
||||
|
@ -190,13 +212,16 @@ Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
|
|||
parallel_group_ranks->add_dim(rank);
|
||||
}
|
||||
}
|
||||
|
||||
if (!CheckPath(group_info_save_file_)) {
|
||||
MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
|
||||
}
|
||||
std::fstream output(group_info_save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
if (!parallel_group_map.SerializeToOstream(&output)) {
|
||||
MS_LOG(ERROR) << "Save strategy file failed";
|
||||
return FAILED;
|
||||
}
|
||||
output.close();
|
||||
ChangeFileMode(group_info_save_file_, S_IRUSR | S_IWUSR);
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace parallel
|
||||
|
|
|
@ -65,6 +65,7 @@ class StrategyCheckpoint {
|
|||
bool load_checkpoint_on_;
|
||||
bool save_checkpoint_on_;
|
||||
bool CheckPointExit(const std::string path) const;
|
||||
bool CheckPath(const std::string path) const;
|
||||
int64_t current_stage_;
|
||||
std::string group_info_save_file_;
|
||||
bool group_info_save_on_;
|
||||
|
|
Loading…
Reference in New Issue