Find the common data-parallel group in auto parallel

This commit is contained in:
yao_yf 2021-11-13 09:53:14 +08:00
parent 04da5c2808
commit 501b978d16
10 changed files with 125 additions and 10 deletions

View File

@ -247,6 +247,32 @@ std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_na
return iter->second;
}
// Convert a hashed group name back to the list of ranks in that group.
// "WORLD_GROUP" expands to all device ranks [0, DeviceNum()); otherwise the
// rank list name is parsed as '-'-separated decimal rank ids (e.g. "0-4-8").
RankList DeviceManager::FindRankListByHashName(const std::string &hash_name) {
  std::string rank_list_name = FindRankListNameByHashName(hash_name);
  if (rank_list_name == "WORLD_GROUP") {
    int64_t device_num = g_device_manager->DeviceNum();
    RankList rank_list;
    for (int64_t i = 0; i < device_num; ++i) {
      rank_list.push_back(i);
    }
    return rank_list;
  }
  RankList rank_list;
  std::string rank_str;
  for (size_t i = 0; i < rank_list_name.size(); i++) {
    if (rank_list_name[i] == '-') {
      // Use stoll, not stoi: rank ids are int64_t and stoi would truncate.
      rank_list.push_back(std::stoll(rank_str));
      rank_str.clear();
    } else if (rank_list_name[i] >= '0' && rank_list_name[i] <= '9') {
      rank_str.push_back(rank_list_name[i]);
    } else {
      MS_LOG(EXCEPTION) << "The rank list name cannot convert to rank list: " << rank_list_name;
    }
  }
  // Flush a trailing rank id that is not followed by a '-' separator; if the
  // name always ends with '-', rank_str is empty here and behavior is unchanged.
  if (!rank_str.empty()) {
    rank_list.push_back(std::stoll(rank_str));
  }
  return rank_list;
}
// Map a group name to the decimal-string form of its std::hash value.
std::string HashName(const std::string &origin_name) {
  const std::size_t hashed = std::hash<std::string>{}(origin_name);
  return std::to_string(hashed);
}
// Group name is generated using the increasing ranks of the devices.

View File

@ -89,6 +89,7 @@ class DeviceManager {
std::string world_group() const { return gm_.world_group(); }
std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return gm_.group_info(); }
std::string FindRankListNameByHashName(const std::string &hash_name);
RankList FindRankListByHashName(const std::string &hash_name);
private:
std::vector<std::shared_ptr<Device>> devices_;

View File

@ -374,7 +374,7 @@ void HandleNoUsedParameter(const FuncGraphPtr &root) {
}
}
static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
if (tensor_layout == nullptr) {
return false;

View File

@ -45,6 +45,7 @@ void HandleFullySplitParameters(const FuncGraphPtr &root);
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root);
void HandleAdaFactorOpt(const FuncGraphPtr &root);
bool ParameterIsCloned(const AnfNodePtr &parameter_node);
bool IsFullySplitParameter(const ParameterPtr &param_ptr);
} // namespace parallel
} // namespace mindspore

View File

@ -3121,11 +3121,53 @@ bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
current_stage == split_stage_num - 1);
}
static void HandleGroupInfo() {
// Find the ranks common to every mirror-operator group in the graph, i.e. the
// data-parallel group shared by all parameters. Returns an empty list when any
// parameter is fully sharded, since no common group exists for this rank then.
RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
  auto parameters = root->parameters();
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    if (IsFullySplitParameter(param_ptr)) {
      MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
                      << " is fully shard, thus cannot find common data parallel group for this rank";
      return {};
    }
  }
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<int64_t> common_group_list;
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  bool is_first_group = true;
  for (auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimMirror)) {
      continue;
    }
    auto prim = GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(prim);  // guard: a mirror cnode should carry a primitive
    if (!prim->HasAttr(GROUP)) {
      MS_LOG(EXCEPTION) << "The mirror operator does not have group attr : " << node->DebugString();
    }
    std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
    std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
    if (is_first_group) {
      common_group_list = group_list;
      is_first_group = false;
    } else {
      // NOTE(review): std::set_intersection requires both inputs sorted
      // ascending; rank lists appear to be generated in increasing rank
      // order — confirm this invariant holds for all group names.
      std::vector<int64_t> new_comm_group_list;
      std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(), group_list.end(),
                            std::back_inserter(new_comm_group_list));
      common_group_list = std::move(new_comm_group_list);
    }
  }
  MS_LOG(INFO) << "The common mirror group is:" << common_group_list;
  return common_group_list;
}
// Persist the communication group info, together with the common mirror group
// used later to restore checkpoints, when group-info saving is enabled.
// (This span contained interleaved pre-change diff lines; resolved to the
// post-change version so the function is syntactically valid.)
static void HandleGroupInfo(const FuncGraphPtr &root) {
  auto group_info = g_device_manager->group_info();
  if (StrategyCheckpoint::GetInstance().group_info_save_on()) {
    RankList comm_group = FindCommonMirrorGroup(root);
    if (StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info, comm_group) != SUCCESS) {
      MS_LOG(EXCEPTION) << "Save group info failed";
    }
  }
}
@ -3239,7 +3281,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
PipelinePostProcess(root, all_nodes);
HandleGroupInfo();
HandleGroupInfo(root);
// handle fully split parameters in grad accumulation; does not include optimizer-sharding's parameters
HandleFullySplitParameters(root);

View File

@ -201,7 +201,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
return SUCCESS;
}
Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map, const RankList &restore_rank_list) {
straspb::ParallelGroupMap parallel_group_map;
for (auto &group : group_info_map) {
straspb::ParallelGroupItem *parallel_group_item = parallel_group_map.add_parallel_group_item();
@ -213,6 +213,10 @@ Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
parallel_group_ranks->add_dim(rank);
}
}
straspb::ParallelGroupRanks *ckpt_restore_rank_list = parallel_group_map.mutable_ckpt_restore_rank_list();
for (auto &restore_rank : restore_rank_list) {
ckpt_restore_rank_list->add_dim(restore_rank);
}
if (!CheckPath(group_info_save_file_)) {
MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
}

View File

@ -52,7 +52,7 @@ class StrategyCheckpoint {
Status Load(StrategyMap *strategy_map);
Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map);
Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map);
Status SaveGroupInfo(const GroupInfoMap &group_info_map);
Status SaveGroupInfo(const GroupInfoMap &group_info_map, const RankList &restore_rank_list);
bool group_info_save_on() const { return group_info_save_on_; }
static StrategyCheckpoint &GetInstance();

View File

@ -74,6 +74,7 @@ message ParallelGroupItem {
message ParallelGroupMap {
repeated ParallelGroupItem parallel_group_item = 1;
required ParallelGroupRanks ckpt_restore_rank_list = 2;
}
message ParallelStrategyMap {

View File

@ -28,7 +28,7 @@ from threading import Thread, Lock
import numpy as np
from mindspore.train.checkpoint_pb2 import Checkpoint
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts, ParallelGroupMap
from mindspore.train.print_pb2 import Print
import mindspore
@ -1160,6 +1160,45 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
return merged_tensor
def ckpt_restore_group_info(group_info_file_name):
    """
    Build the restore rank list: the checkpoints of the ranks in the returned
    list have the same content as the checkpoint of the local rank that saved
    the given group information file.

    Args:
        group_info_file_name (str): Name of the group information file.

    Returns:
        List, the restore rank list.

    Raises:
        ValueError: The group information file is incorrect.
        TypeError: `group_info_file_name` is not a str.

    Examples:
        >>> restore_list = ckpt_restore_group_info("./group_info.ckpt")
    """
    if not isinstance(group_info_file_name, str):
        raise TypeError(f"The group_info_file_name should be str, but got {type(group_info_file_name)}.")
    if not os.path.isfile(group_info_file_name):
        raise ValueError(f"No such group info file: {group_info_file_name}.")
    if os.path.getsize(group_info_file_name) == 0:
        raise ValueError("The group info file should not be empty.")
    parallel_group_map = ParallelGroupMap()
    with open(group_info_file_name, 'rb') as f:
        pb_content = f.read()
    parallel_group_map.ParseFromString(pb_content)
    restore_list = parallel_group_map.ckpt_restore_rank_list
    # A protobuf sub-message is always truthy (no __bool__), so the original
    # `if not restore_list:` could never fire; test the repeated `dim` field.
    if not restore_list.dim:
        raise ValueError("The group info file has no restore rank list.")
    return list(restore_list.dim)
def build_searched_strategy(strategy_filename):
"""
Build strategy of every parameter in network. Used in the case of distributed inference.

View File

@ -34,6 +34,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
// Stub: loading group info is a no-op in this build; always reports success.
Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) { return SUCCESS; }
// Stub: saving group info is a no-op in this build; always reports success.
// (The span held both the old one-arg and new two-arg diff lines; only the
// updated two-argument signature is kept.)
Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map,
                                         const RankList &restore_rank_list) { return SUCCESS; }
} // namespace parallel
} // namespace mindspore