forked from mindspore-Ecosystem/mindspore
find data parallel common group in auto parallel
This commit is contained in:
parent 04da5c2808, commit 501b978d16
@@ -247,6 +247,32 @@ std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name)
   return iter->second;
 }
 
+RankList DeviceManager::FindRankListByHashName(const std::string &hash_name) {
+  std::string rank_list_name = FindRankListNameByHashName(hash_name);
+  if (rank_list_name == "WORLD_GROUP") {
+    int64_t device_num = g_device_manager->DeviceNum();
+    RankList rank_list;
+    for (size_t i = 0; i < size_t(device_num); ++i) {
+      rank_list.push_back(i);
+    }
+    return rank_list;
+  }
+  RankList rank_list;
+  std::string rank_str = "";
+  for (size_t i = 0; i < rank_list_name.size(); i++) {
+    if (rank_list_name[i] == '-') {
+      int64_t rank_id = std::stoi(rank_str);
+      rank_list.push_back(rank_id);
+      rank_str = "";
+    } else if (rank_list_name[i] <= '9' && rank_list_name[i] >= '0') {
+      rank_str.push_back(rank_list_name[i]);
+    } else {
+      MS_LOG(EXCEPTION) << "The rank list name cannot be converted to a rank list: " << rank_list_name;
+    }
+  }
+  return rank_list;
+}
+
 std::string HashName(const std::string &origin_name) { return std::to_string(std::hash<string>{}(origin_name)); }
 
 // Group name is generated using the increasing ranks of the devices.
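In this scheme a communication group is looked up by the hash of its rank-list name (see HashName in the context above), and FindRankListByHashName reverses that mapping and then decodes the rank-list name back into ranks. Below is a minimal Python sketch of just the decoding step, assuming the name is either the literal "WORLD_GROUP" or a string of decimal ranks each terminated by '-' (for example "0-4-8-12-"); the helper name and the example name are illustrative only, not part of the commit.

def find_rank_list_by_name(rank_list_name, device_num):
    # Mirrors DeviceManager::FindRankListByHashName after the hash lookup:
    # "WORLD_GROUP" expands to every rank, otherwise digits accumulate until
    # a dash closes the current rank.
    if rank_list_name == "WORLD_GROUP":
        return list(range(device_num))
    rank_list, rank_str = [], ""
    for ch in rank_list_name:
        if ch == '-':
            rank_list.append(int(rank_str))
            rank_str = ""
        elif ch.isdigit():
            rank_str += ch
        else:
            raise ValueError(f"Cannot convert rank list name: {rank_list_name}")
    return rank_list

# Example (illustrative): find_rank_list_by_name("0-4-8-12-", 16) -> [0, 4, 8, 12]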
@@ -89,6 +89,7 @@ class DeviceManager {
   std::string world_group() const { return gm_.world_group(); }
   std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return gm_.group_info(); }
   std::string FindRankListNameByHashName(const std::string &hash_name);
+  RankList FindRankListByHashName(const std::string &hash_name);
 
  private:
   std::vector<std::shared_ptr<Device>> devices_;
@@ -374,7 +374,7 @@ void HandleNoUsedParameter(const FuncGraphPtr &root) {
   }
 }
 
-static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
+bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
   auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
   if (tensor_layout == nullptr) {
     return false;
@@ -45,6 +45,7 @@ void HandleFullySplitParameters(const FuncGraphPtr &root);
 void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root);
 void HandleAdaFactorOpt(const FuncGraphPtr &root);
 bool ParameterIsCloned(const AnfNodePtr &parameter_node);
+bool IsFullySplitParameter(const ParameterPtr &param_ptr);
 }  // namespace parallel
 }  // namespace mindspore
 
@@ -3121,11 +3121,53 @@ bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
           current_stage == split_stage_num - 1);
 }
 
-static void HandleGroupInfo() {
+RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
+  auto parameters = root->parameters();
+  for (auto &parameter : parameters) {
+    auto param_ptr = parameter->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param_ptr);
+    if (IsFullySplitParameter(param_ptr)) {
+      MS_LOG(WARNING) << "The parameter: " << param_ptr->fullname_with_scope()
+                      << " is fully sharded, thus cannot find a common data parallel group for this rank";
+      return {};
+    }
+  }
+  AnfNodePtr ret = root->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<int64_t> common_group_list;
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  bool is_first_group = true;
+  for (auto &node : all_nodes) {
+    if (!IsPrimitiveCNode(node, prim::kPrimMirror)) {
+      continue;
+    }
+    auto prim = GetCNodePrimitive(node);
+    if (!prim->HasAttr(GROUP)) {
+      MS_LOG(EXCEPTION) << "The mirror operator does not have a group attr: " << node->DebugString();
+    }
+    std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
+    std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
+    if (is_first_group) {
+      common_group_list = group_list;
+      is_first_group = false;
+    } else {
+      std::vector<int64_t> new_comm_group_list;
+      std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(), group_list.end(),
+                            std::back_inserter(new_comm_group_list));
+      common_group_list = new_comm_group_list;
+    }
+  }
+  MS_LOG(INFO) << "The common mirror group is: " << common_group_list;
+  return common_group_list;
+}
+
+static void HandleGroupInfo(const FuncGraphPtr &root) {
   auto group_info = g_device_manager->group_info();
-  if (StrategyCheckpoint::GetInstance().group_info_save_on() &&
-      StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Save group info failed";
+  if (StrategyCheckpoint::GetInstance().group_info_save_on()) {
+    RankList comm_group = FindCommonMirrorGroup(root);
+    if (StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info, comm_group) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Save group info failed";
+    }
   }
 }
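FindCommonMirrorGroup walks every kPrimMirror node in the graph and intersects the rank lists of their communication groups; the result is the set of ranks that appear in every mirror (data parallel) group of the local rank, so a checkpoint saved by any of those ranks should hold the same parameter contents. Fully split parameters have no such group, hence the early return of an empty list. A rough Python sketch of just the intersection step, with hypothetical rank lists as input (the real pass reads the group attr of each mirror CNode and resolves it via FindRankListByHashName):

from functools import reduce

def find_common_mirror_group(mirror_rank_lists):
    # Intersect every mirror operator's rank list; std::set_intersection in the
    # C++ code requires sorted ranges, which sorted() preserves here.
    if not mirror_rank_lists:
        return []
    return list(reduce(lambda a, b: sorted(set(a) & set(b)), mirror_rank_lists))

# Example (hypothetical groups): ranks [0, 1, 2, 3] and [0, 2, 4, 6] share [0, 2].
print(find_common_mirror_group([[0, 1, 2, 3], [0, 2, 4, 6]]))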
@@ -3239,7 +3281,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
 
   PipelinePostProcess(root, all_nodes);
 
-  HandleGroupInfo();
+  HandleGroupInfo(root);
 
   // handle fully split parameters in grad accumulation, do not contain optimizer-sharding's parameter
   HandleFullySplitParameters(root);
@@ -201,7 +201,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
   return SUCCESS;
 }
 
-Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
+Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map, const RankList &restore_rank_list) {
   straspb::ParallelGroupMap parallel_group_map;
   for (auto &group : group_info_map) {
     straspb::ParallelGroupItem *parallel_group_item = parallel_group_map.add_parallel_group_item();
@@ -213,6 +213,10 @@ Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
       parallel_group_ranks->add_dim(rank);
     }
   }
+  straspb::ParallelGroupRanks *ckpt_restore_rank_list = parallel_group_map.mutable_ckpt_restore_rank_list();
+  for (auto &restore_rank : restore_rank_list) {
+    ckpt_restore_rank_list->add_dim(restore_rank);
+  }
   if (!CheckPath(group_info_save_file_)) {
     MS_LOG(EXCEPTION) << "CheckPoint file is invalid";
   }
@@ -52,7 +52,7 @@ class StrategyCheckpoint {
   Status Load(StrategyMap *strategy_map);
   Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map);
   Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map);
-  Status SaveGroupInfo(const GroupInfoMap &group_info_map);
+  Status SaveGroupInfo(const GroupInfoMap &group_info_map, const RankList &restore_rank_list);
   bool group_info_save_on() const { return group_info_save_on_; }
 
   static StrategyCheckpoint &GetInstance();
@@ -74,6 +74,7 @@ message ParallelGroupItem {
 
 message ParallelGroupMap {
   repeated ParallelGroupItem parallel_group_item = 1;
+  required ParallelGroupRanks ckpt_restore_rank_list = 2;
 }
 
 message ParallelStrategyMap {
@@ -28,7 +28,7 @@ from threading import Thread, Lock
 import numpy as np
 from mindspore.train.checkpoint_pb2 import Checkpoint
 from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
-from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts
+from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts, ParallelGroupMap
 from mindspore.train.print_pb2 import Print
 
 import mindspore
@@ -1160,6 +1160,45 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
     return merged_tensor
 
 
+def ckpt_restore_group_info(group_info_file_name):
+    """
+    Build the rank list. The checkpoints of the ranks in this rank list have the same contents as that of
+    the local rank which saves the group_info_file_name.
+    Args:
+        group_info_file_name (str): Name of the group information file.
+
+    Returns:
+        List, the rank list.
+
+    Raises:
+        ValueError: The group information file is incorrect.
+        TypeError: group_info_file_name is not str.
+
+    Examples:
+        >>> restore_list = ckpt_restore_group_info("./group_info.ckpt")
+    """
+    if not isinstance(group_info_file_name, str):
+        raise TypeError(f"The group_info_file_name should be str, but got {type(group_info_file_name)}.")
+
+    if not os.path.isfile(group_info_file_name):
+        raise ValueError(f"No such group info file: {group_info_file_name}.")
+
+    if os.path.getsize(group_info_file_name) == 0:
+        raise ValueError("The group info file should not be empty.")
+
+    parallel_group_map = ParallelGroupMap()
+
+    with open(group_info_file_name, 'rb') as f:
+        pb_content = f.read()
+        parallel_group_map.ParseFromString(pb_content)
+
+    restore_list = parallel_group_map.ckpt_restore_rank_list
+    if not restore_list:
+        raise ValueError("The group info file has no restore rank list.")
+
+    restore_rank_list = [rank for rank in restore_list.dim]
+    return restore_rank_list
+
 def build_searched_strategy(strategy_filename):
     """
     Build strategy of every parameter in network. Used in the case of distributed inference.
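The practical use of the new Python helper: every rank in the returned list saved a parameter checkpoint with the same contents as the local rank, so any of them can serve as the restore source when the local checkpoint is missing. A brief usage sketch follows; the directory layout and checkpoint file names are assumptions for illustration, not defined by this commit.

from mindspore.train.serialization import load_checkpoint

# Pick any rank from the common group and restore from its checkpoint file.
restore_list = ckpt_restore_group_info("./group_info.ckpt")
source_rank = restore_list[0]
param_dict = load_checkpoint(f"./rank_{source_rank}/checkpoint.ckpt")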
@@ -34,6 +34,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
 
 Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) { return SUCCESS; }
 
-Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) { return SUCCESS; }
+Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map,
+                                         const RankList &restore_rank_list) { return SUCCESS; }
 }  // namespace parallel
 }  // namespace mindspore