!8712 update pipeline parallel interface

From: @yangzhenzhang
Reviewed-by: @kisnwang
Signed-off-by:
Authored by mindspore-ci-bot on 2020-11-18 18:59:04 +08:00, committed by Gitee
commit cd6236c0a0
6 changed files with 58 additions and 104 deletions
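In user-facing terms, pipeline parallelism is now configured only through the pipeline_stages number in set_auto_parallel_context (default 1); it must evenly divide the device count and currently requires the semi_auto_parallel mode, with the per-stage device list derived internally. A minimal usage sketch of the updated interface (illustrative only, not part of the diff; the 8-device / 2-stage values are hypothetical):

from mindspore import context

# Hypothetical configuration: 8 devices split evenly into 2 pipeline stages of 4.
# pipeline_stages must divide device_num exactly, and pipeline parallel requires
# the semi_auto_parallel mode (enforced in the ParallelInit changes below).
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  device_num=8,
                                  pipeline_stages=2)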

View File

@@ -63,8 +63,7 @@ void ParallelContext::Reset() {
   all_reduce_fusion_split_indices_.clear();
   all_reduce_fusion_split_sizes_.clear();
   strategy_search_mode_ = DYNAMIC_PROGRAMMING;
-  stages_.clear();
-  pipeline_stage_split_num_ = 0;
+  pipeline_stage_split_num_ = 1;
 }
 
 void ParallelContext::set_device_num(int64_t device_num) {
@ -87,8 +86,6 @@ void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_rep
void ParallelContext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_num; } void ParallelContext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_num; }
void ParallelContext::set_stage(const std::vector<int64_t> &stages) { stages_ = stages; }
bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) { bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode); auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
if (iter == PARALLEL_MODE_LIST.end()) { if (iter == PARALLEL_MODE_LIST.end()) {

View File

@@ -70,9 +70,6 @@ class ParallelContext {
   void set_pipeline_stage_split_num(const int64_t stages);
   int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }
 
-  void set_stage(const std::vector<int64_t> &stages);
-  std::vector<int64_t> stage() const { return stages_; }
-
   void set_global_rank(int64_t global_rank);
   int64_t global_rank() const { return global_rank_; }
@@ -121,8 +118,7 @@ class ParallelContext {
   int64_t global_rank_;
   std::string parallel_mode_;
   std::string strategy_search_mode_;
-  std::vector<int64_t> stages_;
-  int64_t pipeline_stage_split_num_ = 0;
+  int64_t pipeline_stage_split_num_;
   bool parameter_broadcast_;
   bool device_num_is_set_;
   bool global_rank_is_set_;

View File

@@ -54,44 +54,44 @@ bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &back
     MS_LOG(ERROR) << "Invalid backend: " << backend;
     return false;
   }
-  if (stage.empty()) {
-    MS_LOG(ERROR) << "The size of stage must be positive";
-    return false;
-  }
   RankList devices, stage_map;
   for (int64_t i = 0; i < device_num; ++i) {
     devices.push_back(i);
   }
-  if (stage.size()) {
-    int64_t summed_value = 0;
-    for (auto begin = stage.begin(); begin != stage.end(); ++begin) {
-      if (*begin <= 0) {
-        MS_LOG(ERROR) << "The value in the pipeline stages should be positive value";
-        return false;
-      }
-      summed_value += *begin;
-      stage_map.push_back(*begin);
-    }
-    if (summed_value != device_num) {
-      MS_LOG(ERROR) << "The sum of the pipeline stage :" << summed_value << " is not equal to the device_num "
-                    << device_num;
-      return false;
-    }
-  } else {
-    stage_map.push_back(device_num);
-  }
-  for (auto &y : stage_map) {
-    MS_LOG(DEBUG) << "Obtained stage id :" << y;
+  int64_t summed_value = 0;
+  for (auto begin = stage.begin(); begin != stage.end(); ++begin) {
+    if (*begin <= 0) {
+      MS_LOG(ERROR) << "The value in the pipeline stages should be positive value";
+      return false;
+    }
+    summed_value += *begin;
+    stage_map.push_back(*begin);
+  }
+  if (summed_value != device_num) {
+    MS_LOG(ERROR) << "The sum of the pipeline stage :" << summed_value << " is not equal to the device_num "
+                  << device_num;
+    return false;
+  }
+  for (auto &ele : stage_map) {
+    MS_LOG(DEBUG) << "Obtained stage id: " << ele;
   }
   g_device_manager = std::make_shared<DeviceManager>();
   if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) {
     MS_LOG(INFO) << "Device initialization succeeds.";
     return true;
-  } else {
-    MS_LOG(ERROR) << "Device initialization fails.";
-    return false;
   }
+  MS_LOG(ERROR) << "Device initialization fails.";
+  return false;
 }
 
 void CheckGlobalDeviceManager() {
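The InitDevice change above removes the "empty stage list means one stage holding all devices" fallback: callers now always pass a stage list whose entries are positive and sum to device_num. A small illustrative sketch of that invariant (plain Python, not MindSpore code; check_stage_map is a hypothetical name):

def check_stage_map(stage, device_num):
    # Every stage must hold a positive number of devices ...
    if any(s <= 0 for s in stage):
        return False
    # ... and the stage sizes must account for every device exactly once.
    return sum(stage) == device_num

assert check_stage_map([4, 4], 8)        # two stages of four devices each
assert not check_stage_map([4, 4], 6)    # stage sizes do not match device_num -> rejected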

View File

@@ -1125,16 +1125,7 @@ OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAtt
 StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
   ValueTuplePtr var = attrs[STRATEGY]->cast<ValueTuplePtr>();
   StrategyPtr strategyPtr;
-  std::vector<int64_t> stages = ParallelContext::GetInstance()->stage();
-  auto res = attrs.find(STAGE_ATTR);
-  int64_t stage_id = 0;
-  if (res != attrs.end()) {
-    stage_id = GetValue<int64_t>(res->second);
-  }
-  if (stage_id && stages.empty()) {
-    MS_LOG(ERROR) << "Find stage id:" << stage_id << " but the pipeline_stages is 0.";
-    return nullptr;
-  }
+  int64_t stage_id = g_device_manager->stage_id();
 
   MS_LOG(INFO) << "Extract information: strategy " << attrs[STRATEGY]->ToString();
   if (var == nullptr) {
@@ -1152,11 +1143,11 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
                        [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
         strategy.push_back(dim);
       } else {
-        MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequence";
+        MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
       }
     }
     if (strategy.empty()) {
-      MS_LOG(EXCEPTION) << "ExtractStrategy:failed to extract strategy";
+      MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
     }
     strategyPtr = NewStrategy(stage_id, strategy);
   }
@@ -1663,30 +1654,6 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
     (void)prim->SetAttrs(attrs_temp);
   }
 }
 
-// This function aims to check the valid rank and stage in the operations
-// If the rank is not valid for the given stage, we chose not to init the strategy of the operation
-// For example stage is [4, 4], and the group_list [[0,1,2,3],[4,5,6,7]]
-// For stage 0, we require the rank_id is in [0,1,2,3]
-Status ValidRankCheck(int32_t global_rank, int32_t strategy_stage) {
-  RankList local_group_list = g_device_manager->GetDeviceListByStageId(strategy_stage);
-  int32_t target = global_rank;
-  if (std::any_of(local_group_list.begin(), local_group_list.end(), [target](int32_t a) { return a == target; })) {
-    return Status::SUCCESS;
-  }
-  return Status::FAILED;
-}
-
-Status ValidStageCheck(const std::vector<int64_t> &stages, int32_t strategy_stage) {
-  if (stages.size() > 0) {
-    if (strategy_stage >= 0 && strategy_stage < (int32_t)stages.size()) {
-      return Status::SUCCESS;
-    }
-    return Status::FAILED;
-  } else {
-    return Status::SUCCESS;
-  }
-}
-
 // find previous parallel care node.
 bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
@@ -1781,9 +1748,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
     FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
     MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
   }
-  // Get global rank after the checkpoint?
-  int64_t global_rank = ParallelContext::GetInstance()->global_rank();
-  std::vector<int64_t> stages = ParallelContext::GetInstance()->stage();
   for (auto &node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
@@ -1848,18 +1813,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
     if (is_last_nodes && full_batch) {
       SetLastNodeStrategy(strategyPtr);
     }
-    (*operator_).set_stage_id(strategyPtr->GetInputStage());
-    MS_LOG(INFO) << "Extract stage id for op " << prim->name() << " is " << (*operator_).stage_id();
-    if (ValidStageCheck(stages, (*operator_).stage_id()) == FAILED) {
-      MS_LOG(ERROR) << "Find stage " << strategyPtr->GetInputStage() << " for operator " << prim->name()
-                    << " exceeds the global stage size " << stages.size() << '.';
-      return;
-    }
-    // If the strategy is not valid for the given global rank, then we skip the Init of the strategy
-    if (ValidRankCheck(global_rank, (*operator_).stage_id()) == FAILED) {
-      MS_LOG(INFO) << "Find global exceeds the range of the stage, skip the strategy init for operator "
-                   << prim->name();
-    } else if (operator_->Init(strategyPtr) == FAILED) {
+    if (operator_->Init(strategyPtr) == FAILED) {
       MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
     }
     cnode->set_user_data<OperatorInfo>(operator_);
@@ -2800,7 +2754,6 @@ Status ParallelInit() {
   int64_t device_num = ParallelContext::GetInstance()->device_num();
   int64_t global_rank = ParallelContext::GetInstance()->global_rank();
   int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
-  std::vector<int64_t> stages = ParallelContext::GetInstance()->stage();
   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
@@ -2814,28 +2767,14 @@ Status ParallelInit() {
     world_group = NCCL_WORLD_GROUP;
     communication_backend = NCCL_BACKEND;
   } else {
-    MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
-  }
-  if (device_num <= 0) {
-    MS_LOG(ERROR) << "Invalid device num " << device_num << " , expected a positive device number";
-    return FAILED;
-  }
-  if (split_stage_num > 0) {
-    if (device_num % split_stage_num != 0) {
-      MS_LOG(ERROR) << "Device num " << device_num << " can't be divided by stage num " << split_stage_num
-                    << " , as we support only extract devision now";
-      return FAILED;
-    }
-    for (int i = 0; i < split_stage_num; i++) {
-      stages.push_back(device_num / split_stage_num);
-    }
-  } else if (split_stage_num < 0) {
-    MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << " , expected a positive stage number";
+    MS_LOG(ERROR) << "Invalid communication backend: " << backend;
     return FAILED;
   }
-  ParallelContext::GetInstance()->set_stage(stages);
+  if (split_stage_num <= 0) {
+    MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << ", expected a positive stage number";
+    return FAILED;
+  }
 
   uint32_t world_rank_size = 0;
   if (!ParallelContext::GetInstance()->device_num_is_set()) {
@@ -2855,7 +2794,28 @@ Status ParallelInit() {
     MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
   }
 
-  if (!stages.empty() && parallel_mode != SEMI_AUTO_PARALLEL) {
+  if ((device_num <= 0) || (device_num > MAX_DEVICE_NUM)) {
+    MS_LOG(ERROR) << "Invalid device num " << device_num;
+    return FAILED;
+  }
+
+  // the device_num maybe get from communication interface
+  if (device_num % split_stage_num != 0) {
+    MS_LOG(ERROR) << "Device num " << device_num << " can't be divided by stage num " << split_stage_num;
+    return FAILED;
+  }
+
+  if ((global_rank < 0) || (global_rank >= device_num)) {
+    MS_LOG(ERROR) << "Global rank " << global_rank << " is out of range, the device num is " << device_num;
+    return FAILED;
+  }
+
+  std::vector<int64_t> stages;
+  for (int i = 0; i < split_stage_num; i++) {
+    stages.push_back(device_num / split_stage_num);
+  }
+
+  if ((split_stage_num > 1) && (parallel_mode != SEMI_AUTO_PARALLEL)) {
     MS_LOG(ERROR) << "To enable the pipeline parallel, please set the parallel mode to " << SEMI_AUTO_PARALLEL;
     return FAILED;
   }
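With the stage() accessor gone, ParallelInit derives the per-stage device counts itself: pipeline_stages must be a positive divisor of device_num, and each stage receives an equal share. An illustrative sketch of that derivation (plain Python, not MindSpore code; build_stages is a hypothetical name):

def build_stages(device_num, split_stage_num):
    # Mirrors the checks above: a positive stage count that divides device_num exactly.
    if split_stage_num <= 0:
        raise ValueError("expected a positive stage number")
    if device_num % split_stage_num != 0:
        raise ValueError("device num can't be divided by stage num")
    return [device_num // split_stage_num] * split_stage_num

print(build_stages(8, 2))  # [4, 4]: two pipeline stages of four devices each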

View File

@@ -391,7 +391,7 @@ def set_auto_parallel_context(**kwargs):
         pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
                     the devices are distributed alone the pipeline. The total devices will be divided into
                     'pipeline_stags' stages. This currently could only be used when
-                    parall mode semi_auto_parallel is enabled.
+                    parallel mode semi_auto_parallel is enabled. Default: 1.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -444,6 +444,7 @@ def reset_auto_parallel_context():
     - strategy_ckpt_save_file: ''.
     - full_batch: False.
     - enable_parallel_optimizer: False.
+    - pipeline_stages: 1.
     """
     _reset_auto_parallel_context()
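As the docstring change notes, pipeline_stages now has an explicit default of 1 and is restored to 1 by reset_auto_parallel_context(). A hedged example of observing that default (assuming the public get_auto_parallel_context accessor accepts the "pipeline_stages" key, as the updated test below does for its internal counterpart):

from mindspore import context

context.set_auto_parallel_context(pipeline_stages=4)
context.reset_auto_parallel_context()
# After the reset, the stage count should be back at the documented default of 1.
assert context.get_auto_parallel_context("pipeline_stages") == 1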

View File

@@ -107,4 +107,4 @@ def test_reset_auto_parallel_context():
     assert not parameter_broadcast
     assert not device_num_is_set
     assert not parameter_broadcast_is_set
-    assert not stage
+    assert stage == 1