forked from mindspore-Ecosystem/mindspore

find data parallel common group in auto parallel

commit 501b978d16 (parent 04da5c2808)
@@ -247,6 +247,32 @@ std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name
   return iter->second;
 }
 
+RankList DeviceManager::FindRankListByHashName(const std::string &hash_name) {
+  std::string rank_list_name = FindRankListNameByHashName(hash_name);
+  if (rank_list_name == "WORLD_GROUP") {
+    int64_t device_num = g_device_manager->DeviceNum();
+    RankList rank_list;
+    for (size_t i = 0; i < size_t(device_num); ++i) {
+      rank_list.push_back(i);
+    }
+    return rank_list;
+  }
+  RankList rank_list;
+  std::string rank_str = "";
+  for (size_t i = 0; i < rank_list_name.size(); i++) {
+    if (rank_list_name[i] == '-') {
+      int64_t rank_id = std::stoi(rank_str);
+      rank_list.push_back(rank_id);
+      rank_str = "";
+    } else if (rank_list_name[i] <= '9' && rank_list_name[i] >= '0') {
+      rank_str.push_back(rank_list_name[i]);
+    } else {
+      MS_LOG(EXCEPTION) << "The rank list name cannot convert to rank list: " << rank_list_name;
+    }
+  }
+  return rank_list;
+}
+
 std::string HashName(const std::string &origin_name) { return std::to_string(std::hash<string>{}(origin_name)); }
 
 // Group name is generated using the increasing ranks of the devices.
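The new DeviceManager::FindRankListByHashName maps a mirror group's hashed name back to its rank list: "WORLD_GROUP" expands to every device, and any other rank-list name is read as dash-separated rank ids. A rough Python sketch of that mapping, for illustration only (the helper name and example inputs are assumptions, not MindSpore API):

def rank_list_from_name(rank_list_name, device_num):
    """Sketch: recover the rank ids encoded in a rank-list name."""
    if rank_list_name == "WORLD_GROUP":
        return list(range(device_num))           # the world group covers every device
    ranks = []
    for part in rank_list_name.split("-"):       # names assumed to look like "0-2-4-6"
        if not part.isdigit():
            raise ValueError(f"cannot convert rank list name: {rank_list_name}")
        ranks.append(int(part))
    return ranks

# rank_list_from_name("0-2-4-6", 8)     -> [0, 2, 4, 6]
# rank_list_from_name("WORLD_GROUP", 4) -> [0, 1, 2, 3]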
@@ -89,6 +89,7 @@ class DeviceManager {
   std::string world_group() const { return gm_.world_group(); }
   std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return gm_.group_info(); }
   std::string FindRankListNameByHashName(const std::string &hash_name);
+  RankList FindRankListByHashName(const std::string &hash_name);
 
  private:
   std::vector<std::shared_ptr<Device>> devices_;
@@ -374,7 +374,7 @@ void HandleNoUsedParameter(const FuncGraphPtr &root) {
   }
 }
 
-static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
+bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
   auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
   if (tensor_layout == nullptr) {
     return false;
@@ -45,6 +45,7 @@ void HandleFullySplitParameters(const FuncGraphPtr &root);
 void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root);
 void HandleAdaFactorOpt(const FuncGraphPtr &root);
 bool ParameterIsCloned(const AnfNodePtr &parameter_node);
+bool IsFullySplitParameter(const ParameterPtr &param_ptr);
 }  // namespace parallel
 }  // namespace mindspore
 
@@ -3121,11 +3121,53 @@ bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
                           current_stage == split_stage_num - 1);
 }
 
-static void HandleGroupInfo() {
+RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
+  auto parameters = root->parameters();
+  for (auto &parameter : parameters) {
+    auto param_ptr = parameter->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param_ptr);
+    if (IsFullySplitParameter(param_ptr)) {
+      MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
+                      << " is fully sharded, thus cannot find common data parallel group for this rank";
+      return {};
+    }
+  }
+  AnfNodePtr ret = root->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<int64_t> common_group_list;
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  bool is_first_group = true;
+  for (auto &node : all_nodes) {
+    if (!IsPrimitiveCNode(node, prim::kPrimMirror)) {
+      continue;
+    }
+    auto prim = GetCNodePrimitive(node);
+    if (!prim->HasAttr(GROUP)) {
+      MS_LOG(EXCEPTION) << "The mirror operator does not have group attr : " << node->DebugString();
+    }
+    std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
+    std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
+    if (is_first_group) {
+      common_group_list = group_list;
+      is_first_group = false;
+    } else {
+      std::vector<int64_t> new_comm_group_list;
+      std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(), group_list.end(),
+                            std::back_inserter(new_comm_group_list));
+      common_group_list = new_comm_group_list;
+    }
+  }
+  MS_LOG(INFO) << "The common mirror group is:" << common_group_list;
+  return common_group_list;
+}
+
+static void HandleGroupInfo(const FuncGraphPtr &root) {
   auto group_info = g_device_manager->group_info();
-  if (StrategyCheckpoint::GetInstance().group_info_save_on() &&
-      StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Save group info failed";
+  if (StrategyCheckpoint::GetInstance().group_info_save_on()) {
+    RankList comm_group = FindCommonMirrorGroup(root);
+    if (StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info, comm_group) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Save group info failed";
+    }
   }
 }
 
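FindCommonMirrorGroup walks every Mirror operator in the graph, resolves each operator's group attribute back to a rank list via FindRankListByHashName, and intersects all of those lists; the ranks that survive the intersection hold identical copies of every data-parallel parameter, so their checkpoints are interchangeable (and the search bails out early if any parameter is fully sharded). A rough Python sketch of just the intersection step, for illustration only (the input is a hypothetical list of per-mirror rank lists; the real code extracts them from the graph):

def find_common_mirror_group(mirror_group_rank_lists):
    """Sketch: intersect the rank lists of all mirror (gradient-sync) groups."""
    common = None
    for group in mirror_group_rank_lists:
        common = set(group) if common is None else common & set(group)
    return sorted(common) if common else []

# Two mirror groups on an 8-device job: the common data-parallel group is [0, 4].
# find_common_mirror_group([[0, 2, 4, 6], [0, 4]]) -> [0, 4]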
@@ -3239,7 +3281,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
 
   PipelinePostProcess(root, all_nodes);
 
-  HandleGroupInfo();
+  HandleGroupInfo(root);
 
   // handle full split parammeters in grad accumulation, do not contain optimizer-sharding's parameter
   HandleFullySplitParameters(root);
@@ -201,7 +201,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map
   return SUCCESS;
 }
 
-Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
+Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map, const RankList &restore_rank_list) {
   straspb::ParallelGroupMap parallel_group_map;
   for (auto &group : group_info_map) {
     straspb::ParallelGroupItem *parallel_group_item = parallel_group_map.add_parallel_group_item();
@@ -213,6 +213,10 @@ Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
       parallel_group_ranks->add_dim(rank);
     }
   }
+  straspb::ParallelGroupRanks *ckpt_restore_rank_list = parallel_group_map.mutable_ckpt_restore_rank_list();
+  for (auto &restore_rank : restore_rank_list) {
+    ckpt_restore_rank_list->add_dim(restore_rank);
+  }
   if (!CheckPath(group_info_save_file_)) {
     MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
   }
@@ -52,7 +52,7 @@ class StrategyCheckpoint {
   Status Load(StrategyMap *strategy_map);
   Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map);
   Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map);
-  Status SaveGroupInfo(const GroupInfoMap &group_info_map);
+  Status SaveGroupInfo(const GroupInfoMap &group_info_map, const RankList &restore_rank_list);
   bool group_info_save_on() const { return group_info_save_on_; }
 
   static StrategyCheckpoint &GetInstance();
@@ -74,6 +74,7 @@ message ParallelGroupItem {
 
 message ParallelGroupMap {
   repeated ParallelGroupItem parallel_group_item = 1;
+  required ParallelGroupRanks ckpt_restore_rank_list = 2;
 }
 
 message ParallelStrategyMap {
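The new required ckpt_restore_rank_list field is a ParallelGroupRanks message whose repeated dim entries carry the common mirror group computed above. A small round-trip sketch using the generated Python bindings (the import path comes from the serialization.py hunk below; the rank values are made up for illustration):

from mindspore.train.node_strategy_pb2 import ParallelGroupMap

group_map = ParallelGroupMap()
group_map.ckpt_restore_rank_list.dim.extend([0, 4])   # common mirror group ranks

blob = group_map.SerializeToString()
restored = ParallelGroupMap()
restored.ParseFromString(blob)
print(list(restored.ckpt_restore_rank_list.dim))      # [0, 4]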
@@ -28,7 +28,7 @@ from threading import Thread, Lock
 import numpy as np
 
 from mindspore.train.checkpoint_pb2 import Checkpoint
 from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
-from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts
+from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts, ParallelGroupMap
 from mindspore.train.print_pb2 import Print
 
 import mindspore
@@ -1160,6 +1160,45 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
     return merged_tensor
 
 
+def ckpt_restore_group_info(group_info_file_name):
+    """
+    Build the rank list: the checkpoints saved by the ranks in this list have the same contents as the one
+    saved by the local rank that produced group_info_file_name.
+    Args:
+        group_info_file_name (str): Name of the group information file.
+
+    Returns:
+        List, the rank list.
+
+    Raises:
+        ValueError: The group information file is incorrect.
+        TypeError: group_info_file_name is not str.
+
+    Examples:
+        >>> restore_list = ckpt_restore_group_info("./group_info.ckpt")
+    """
+    if not isinstance(group_info_file_name, str):
+        raise TypeError(f"The group_info_file_name should be str, but got {type(group_info_file_name)}.")
+
+    if not os.path.isfile(group_info_file_name):
+        raise ValueError(f"No such group info file: {group_info_file_name}.")
+
+    if os.path.getsize(group_info_file_name) == 0:
+        raise ValueError("The group info file should not be empty.")
+
+    parallel_group_map = ParallelGroupMap()
+
+    with open(group_info_file_name, 'rb') as f:
+        pb_content = f.read()
+    parallel_group_map.ParseFromString(pb_content)
+
+    restore_list = parallel_group_map.ckpt_restore_rank_list
+    if not restore_list:
+        raise ValueError("The group info file has no restore rank list.")
+
+    restore_rank_list = [rank for rank in restore_list.dim]
+    return restore_rank_list
+
 def build_searched_strategy(strategy_filename):
     """
     Build strategy of every parameter in network. Used in the case of distributed inference.
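With the restore rank list persisted, a relaunched job can ask which ranks saved checkpoints interchangeable with its own and fall back to any of them. A hypothetical restore flow (the file layout and environment variable are placeholders; load_checkpoint is the existing MindSpore loader, and the resulting param_dict would then be passed to load_param_into_net):

import os
from mindspore.train.serialization import load_checkpoint, ckpt_restore_group_info

rank_id = int(os.environ.get("RANK_ID", "0"))
restore_ranks = ckpt_restore_group_info(f"./group_info_{rank_id}.ckpt")

# Every rank in restore_ranks saved the same parameter contents as this rank,
# so the first checkpoint file that is actually present can stand in for the local one.
param_dict = None
for rank in restore_ranks:
    ckpt_file = f"./ckpt/rank_{rank}/net.ckpt"        # hypothetical per-rank layout
    if os.path.isfile(ckpt_file):
        param_dict = load_checkpoint(ckpt_file)       # then load_param_into_net(net, param_dict)
        break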
@@ -34,6 +34,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map
 
 Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) { return SUCCESS; }
 
-Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) { return SUCCESS; }
+Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map,
+                                         const RankList &restore_rank_list) { return SUCCESS; }
 }  // namespace parallel
 }  // namespace mindspore