code check of strategy checkpoint

This commit is contained in:
yao_yf 2021-06-11 10:06:38 +08:00
parent 3f7e42e80c
commit 61d07f98a8
2 changed files with 28 additions and 2 deletions

View File

@ -22,6 +22,7 @@
#include "utils/ms_utils.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"
#include "debug/common.h"
#include "proto/node_strategy.pb.h"
namespace mindspore {
@ -39,6 +40,19 @@ StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
return instance;
}
bool StrategyCheckpoint::CheckPath(const std::string path) const {
if (path.size() > PATH_MAX) {
MS_LOG(ERROR) << "The checkpoit path " << path << " is too long";
return false;
}
auto realpath = Common::GetRealPath(path);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << path;
return false;
}
return true;
}
bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
std::ifstream fin(path);
if (fin) {
@ -49,6 +63,9 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) {
MS_EXCEPTION_IF_NULL(group_info_map);
if (!CheckPath(file)) {
MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
}
if (!CheckPointExit(file)) {
MS_LOG(EXCEPTION) << "CheckPoint file is not found";
}
@ -84,6 +101,9 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
if (strategy_map == nullptr) {
MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
}
if (!CheckPath(load_file_)) {
MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
}
if (!CheckPointExit(load_file_)) {
MS_LOG(EXCEPTION) << "CheckPoint file is not found";
}
@ -167,7 +187,9 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
}
if (!CheckPath(save_file_)) {
MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
}
std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
if (!parallel_strategy_map.SerializeToOstream(&output)) {
MS_LOG(ERROR) << "Save strategy file failed";
@ -190,13 +212,16 @@ Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
parallel_group_ranks->add_dim(rank);
}
}
if (!CheckPath(group_info_save_file_)) {
MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
}
std::fstream output(group_info_save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
if (!parallel_group_map.SerializeToOstream(&output)) {
MS_LOG(ERROR) << "Save strategy file failed";
return FAILED;
}
output.close();
ChangeFileMode(group_info_save_file_, S_IRUSR | S_IWUSR);
return SUCCESS;
}
} // namespace parallel

View File

@ -65,6 +65,7 @@ class StrategyCheckpoint {
bool load_checkpoint_on_;
bool save_checkpoint_on_;
bool CheckPointExit(const std::string path) const;
bool CheckPath(const std::string path) const;
int64_t current_stage_;
std::string group_info_save_file_;
bool group_info_save_on_;