forked from mindspore-Ecosystem/mindspore
!8712 update pipeline parallel interface
From: @yangzhenzhang Reviewed-by: @kisnwang Signed-off-by:
This commit is contained in:
commit
cd6236c0a0
|
@ -63,8 +63,7 @@ void ParallelContext::Reset() {
|
||||||
all_reduce_fusion_split_indices_.clear();
|
all_reduce_fusion_split_indices_.clear();
|
||||||
all_reduce_fusion_split_sizes_.clear();
|
all_reduce_fusion_split_sizes_.clear();
|
||||||
strategy_search_mode_ = DYNAMIC_PROGRAMMING;
|
strategy_search_mode_ = DYNAMIC_PROGRAMMING;
|
||||||
stages_.clear();
|
pipeline_stage_split_num_ = 1;
|
||||||
pipeline_stage_split_num_ = 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ParallelContext::set_device_num(int64_t device_num) {
|
void ParallelContext::set_device_num(int64_t device_num) {
|
||||||
|
@ -87,8 +86,6 @@ void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_rep
|
||||||
|
|
||||||
void ParallelContext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_num; }
|
void ParallelContext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_num; }
|
||||||
|
|
||||||
void ParallelContext::set_stage(const std::vector<int64_t> &stages) { stages_ = stages; }
|
|
||||||
|
|
||||||
bool ParallelContext::set_parallel_mode(const std::string ¶llel_mode) {
|
bool ParallelContext::set_parallel_mode(const std::string ¶llel_mode) {
|
||||||
auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
|
auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
|
||||||
if (iter == PARALLEL_MODE_LIST.end()) {
|
if (iter == PARALLEL_MODE_LIST.end()) {
|
||||||
|
|
|
@ -70,9 +70,6 @@ class ParallelContext {
|
||||||
void set_pipeline_stage_split_num(const int64_t stages);
|
void set_pipeline_stage_split_num(const int64_t stages);
|
||||||
int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }
|
int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }
|
||||||
|
|
||||||
void set_stage(const std::vector<int64_t> &stages);
|
|
||||||
std::vector<int64_t> stage() const { return stages_; }
|
|
||||||
|
|
||||||
void set_global_rank(int64_t global_rank);
|
void set_global_rank(int64_t global_rank);
|
||||||
int64_t global_rank() const { return global_rank_; }
|
int64_t global_rank() const { return global_rank_; }
|
||||||
|
|
||||||
|
@ -121,8 +118,7 @@ class ParallelContext {
|
||||||
int64_t global_rank_;
|
int64_t global_rank_;
|
||||||
std::string parallel_mode_;
|
std::string parallel_mode_;
|
||||||
std::string strategy_search_mode_;
|
std::string strategy_search_mode_;
|
||||||
std::vector<int64_t> stages_;
|
int64_t pipeline_stage_split_num_;
|
||||||
int64_t pipeline_stage_split_num_ = 0;
|
|
||||||
bool parameter_broadcast_;
|
bool parameter_broadcast_;
|
||||||
bool device_num_is_set_;
|
bool device_num_is_set_;
|
||||||
bool global_rank_is_set_;
|
bool global_rank_is_set_;
|
||||||
|
|
|
@ -54,44 +54,44 @@ bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &back
|
||||||
MS_LOG(ERROR) << "Invalid backend: " << backend;
|
MS_LOG(ERROR) << "Invalid backend: " << backend;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (stage.empty()) {
|
||||||
|
MS_LOG(ERROR) << "The size of stage must be positive";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
RankList devices, stage_map;
|
RankList devices, stage_map;
|
||||||
for (int64_t i = 0; i < device_num; ++i) {
|
for (int64_t i = 0; i < device_num; ++i) {
|
||||||
devices.push_back(i);
|
devices.push_back(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (stage.size()) {
|
int64_t summed_value = 0;
|
||||||
int64_t summed_value = 0;
|
for (auto begin = stage.begin(); begin != stage.end(); ++begin) {
|
||||||
for (auto begin = stage.begin(); begin != stage.end(); ++begin) {
|
if (*begin <= 0) {
|
||||||
if (*begin <= 0) {
|
MS_LOG(ERROR) << "The value in the pipeline stages should be positive value";
|
||||||
MS_LOG(ERROR) << "The value in the pipeline stages should be positive value";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
summed_value += *begin;
|
|
||||||
stage_map.push_back(*begin);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (summed_value != device_num) {
|
|
||||||
MS_LOG(ERROR) << "The sum of the pipeline stage :" << summed_value << " is not equal to the device_num "
|
|
||||||
<< device_num;
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} else {
|
summed_value += *begin;
|
||||||
stage_map.push_back(device_num);
|
stage_map.push_back(*begin);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &y : stage_map) {
|
if (summed_value != device_num) {
|
||||||
MS_LOG(DEBUG) << "Obtained stage id :" << y;
|
MS_LOG(ERROR) << "The sum of the pipeline stage :" << summed_value << " is not equal to the device_num "
|
||||||
|
<< device_num;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &ele : stage_map) {
|
||||||
|
MS_LOG(DEBUG) << "Obtained stage id: " << ele;
|
||||||
}
|
}
|
||||||
|
|
||||||
g_device_manager = std::make_shared<DeviceManager>();
|
g_device_manager = std::make_shared<DeviceManager>();
|
||||||
if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) {
|
if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) {
|
||||||
MS_LOG(INFO) << "Device initialization succeeds.";
|
MS_LOG(INFO) << "Device initialization succeeds.";
|
||||||
return true;
|
return true;
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Device initialization fails.";
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MS_LOG(ERROR) << "Device initialization fails.";
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CheckGlobalDeviceManager() {
|
void CheckGlobalDeviceManager() {
|
||||||
|
|
|
@ -1125,16 +1125,7 @@ OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAtt
|
||||||
StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
|
StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
|
||||||
ValueTuplePtr var = attrs[STRATEGY]->cast<ValueTuplePtr>();
|
ValueTuplePtr var = attrs[STRATEGY]->cast<ValueTuplePtr>();
|
||||||
StrategyPtr strategyPtr;
|
StrategyPtr strategyPtr;
|
||||||
std::vector<int64_t> stages = ParallelContext::GetInstance()->stage();
|
int64_t stage_id = g_device_manager->stage_id();
|
||||||
auto res = attrs.find(STAGE_ATTR);
|
|
||||||
int64_t stage_id = 0;
|
|
||||||
if (res != attrs.end()) {
|
|
||||||
stage_id = GetValue<int64_t>(res->second);
|
|
||||||
}
|
|
||||||
if (stage_id && stages.empty()) {
|
|
||||||
MS_LOG(ERROR) << "Find stage id:" << stage_id << " but the pipeline_stages is 0.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_LOG(INFO) << "Extract information: strategy " << attrs[STRATEGY]->ToString();
|
MS_LOG(INFO) << "Extract information: strategy " << attrs[STRATEGY]->ToString();
|
||||||
if (var == nullptr) {
|
if (var == nullptr) {
|
||||||
|
@ -1152,11 +1143,11 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
|
||||||
[](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
|
[](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
|
||||||
strategy.push_back(dim);
|
strategy.push_back(dim);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequence";
|
MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (strategy.empty()) {
|
if (strategy.empty()) {
|
||||||
MS_LOG(EXCEPTION) << "ExtractStrategy:failed to extract strategy";
|
MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
|
||||||
}
|
}
|
||||||
strategyPtr = NewStrategy(stage_id, strategy);
|
strategyPtr = NewStrategy(stage_id, strategy);
|
||||||
}
|
}
|
||||||
|
@ -1663,30 +1654,6 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
|
||||||
(void)prim->SetAttrs(attrs_temp);
|
(void)prim->SetAttrs(attrs_temp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// This function aims to check the valid rank and stage in the operations
|
|
||||||
// If the rank is not valid for the given stage, we chose not to init the strategy of the operation
|
|
||||||
// For example stage is [4, 4], and the group_list [[0,1,2,3],[4,5,6,7]]
|
|
||||||
// For stage 0, we require the rank_id is in [0,1,2,3]
|
|
||||||
Status ValidRankCheck(int32_t global_rank, int32_t strategy_stage) {
|
|
||||||
RankList local_group_list = g_device_manager->GetDeviceListByStageId(strategy_stage);
|
|
||||||
int32_t target = global_rank;
|
|
||||||
if (std::any_of(local_group_list.begin(), local_group_list.end(), [target](int32_t a) { return a == target; })) {
|
|
||||||
return Status::SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ValidStageCheck(const std::vector<int64_t> &stages, int32_t strategy_stage) {
|
|
||||||
if (stages.size() > 0) {
|
|
||||||
if (strategy_stage >= 0 && strategy_stage < (int32_t)stages.size()) {
|
|
||||||
return Status::SUCCESS;
|
|
||||||
}
|
|
||||||
return Status::FAILED;
|
|
||||||
} else {
|
|
||||||
return Status::SUCCESS;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// find previous parallel care node.
|
// find previous parallel care node.
|
||||||
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
|
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
|
||||||
|
@ -1781,9 +1748,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
|
||||||
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
|
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
|
||||||
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
|
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
|
||||||
}
|
}
|
||||||
// Get global rank after the checkpoint?
|
|
||||||
int64_t global_rank = ParallelContext::GetInstance()->global_rank();
|
|
||||||
std::vector<int64_t> stages = ParallelContext::GetInstance()->stage();
|
|
||||||
for (auto &node : all_nodes) {
|
for (auto &node : all_nodes) {
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
||||||
|
@ -1848,18 +1813,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
|
||||||
if (is_last_nodes && full_batch) {
|
if (is_last_nodes && full_batch) {
|
||||||
SetLastNodeStrategy(strategyPtr);
|
SetLastNodeStrategy(strategyPtr);
|
||||||
}
|
}
|
||||||
(*operator_).set_stage_id(strategyPtr->GetInputStage());
|
if (operator_->Init(strategyPtr) == FAILED) {
|
||||||
MS_LOG(INFO) << "Extract stage id for op " << prim->name() << " is " << (*operator_).stage_id();
|
|
||||||
if (ValidStageCheck(stages, (*operator_).stage_id()) == FAILED) {
|
|
||||||
MS_LOG(ERROR) << "Find stage " << strategyPtr->GetInputStage() << " for operator " << prim->name()
|
|
||||||
<< " exceeds the global stage size " << stages.size() << '.';
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// If the strategy is not valid for the given global rank, then we skip the Init of the strategy
|
|
||||||
if (ValidRankCheck(global_rank, (*operator_).stage_id()) == FAILED) {
|
|
||||||
MS_LOG(INFO) << "Find global exceeds the range of the stage, skip the strategy init for operator "
|
|
||||||
<< prim->name();
|
|
||||||
} else if (operator_->Init(strategyPtr) == FAILED) {
|
|
||||||
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
|
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
|
||||||
}
|
}
|
||||||
cnode->set_user_data<OperatorInfo>(operator_);
|
cnode->set_user_data<OperatorInfo>(operator_);
|
||||||
|
@ -2800,7 +2754,6 @@ Status ParallelInit() {
|
||||||
int64_t device_num = ParallelContext::GetInstance()->device_num();
|
int64_t device_num = ParallelContext::GetInstance()->device_num();
|
||||||
int64_t global_rank = ParallelContext::GetInstance()->global_rank();
|
int64_t global_rank = ParallelContext::GetInstance()->global_rank();
|
||||||
int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
|
int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
|
||||||
std::vector<int64_t> stages = ParallelContext::GetInstance()->stage();
|
|
||||||
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
|
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
|
||||||
auto ms_context = MsContext::GetInstance();
|
auto ms_context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(ms_context);
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
|
@ -2814,28 +2767,14 @@ Status ParallelInit() {
|
||||||
world_group = NCCL_WORLD_GROUP;
|
world_group = NCCL_WORLD_GROUP;
|
||||||
communication_backend = NCCL_BACKEND;
|
communication_backend = NCCL_BACKEND;
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
|
MS_LOG(ERROR) << "Invalid communication backend: " << backend;
|
||||||
}
|
|
||||||
|
|
||||||
if (device_num <= 0) {
|
|
||||||
MS_LOG(ERROR) << "Invalid device num " << device_num << " , expected a positive device number";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
if (split_stage_num > 0) {
|
|
||||||
if (device_num % split_stage_num != 0) {
|
|
||||||
MS_LOG(ERROR) << "Device num " << device_num << " can't be divided by stage num " << split_stage_num
|
|
||||||
<< " , as we support only extract devision now";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < split_stage_num; i++) {
|
|
||||||
stages.push_back(device_num / split_stage_num);
|
|
||||||
}
|
|
||||||
} else if (split_stage_num < 0) {
|
|
||||||
MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << " , expected a positive stage number";
|
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
ParallelContext::GetInstance()->set_stage(stages);
|
if (split_stage_num <= 0) {
|
||||||
|
MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << ", expected a positive stage number";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t world_rank_size = 0;
|
uint32_t world_rank_size = 0;
|
||||||
if (!ParallelContext::GetInstance()->device_num_is_set()) {
|
if (!ParallelContext::GetInstance()->device_num_is_set()) {
|
||||||
|
@ -2855,7 +2794,28 @@ Status ParallelInit() {
|
||||||
MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
|
MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!stages.empty() && parallel_mode != SEMI_AUTO_PARALLEL) {
|
if ((device_num <= 0) || (device_num > MAX_DEVICE_NUM)) {
|
||||||
|
MS_LOG(ERROR) << "Invalid device num " << device_num;
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
// the device_num maybe get from communication interface
|
||||||
|
if (device_num % split_stage_num != 0) {
|
||||||
|
MS_LOG(ERROR) << "Device num " << device_num << " can't be divided by stage num " << split_stage_num;
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((global_rank < 0) || (global_rank >= device_num)) {
|
||||||
|
MS_LOG(ERROR) << "Global rank " << global_rank << " is out of range, the device num is " << device_num;
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> stages;
|
||||||
|
for (int i = 0; i < split_stage_num; i++) {
|
||||||
|
stages.push_back(device_num / split_stage_num);
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((split_stage_num > 1) && (parallel_mode != SEMI_AUTO_PARALLEL)) {
|
||||||
MS_LOG(ERROR) << "To enable the pipeline parallel, please set the parallel mode to " << SEMI_AUTO_PARALLEL;
|
MS_LOG(ERROR) << "To enable the pipeline parallel, please set the parallel mode to " << SEMI_AUTO_PARALLEL;
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
|
@ -391,7 +391,7 @@ def set_auto_parallel_context(**kwargs):
|
||||||
pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
|
pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
|
||||||
the devices are distributed alone the pipeline. The total devices will be divided into
|
the devices are distributed alone the pipeline. The total devices will be divided into
|
||||||
'pipeline_stags' stages. This currently could only be used when
|
'pipeline_stags' stages. This currently could only be used when
|
||||||
parall mode semi_auto_parallel is enabled.
|
parallel mode semi_auto_parallel is enabled. Default: 1.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If input key is not attribute in auto parallel context.
|
ValueError: If input key is not attribute in auto parallel context.
|
||||||
|
@ -444,6 +444,7 @@ def reset_auto_parallel_context():
|
||||||
- strategy_ckpt_save_file: ''.
|
- strategy_ckpt_save_file: ''.
|
||||||
- full_batch: False.
|
- full_batch: False.
|
||||||
- enable_parallel_optimizer: False.
|
- enable_parallel_optimizer: False.
|
||||||
|
- pipeline_stages: 1.
|
||||||
"""
|
"""
|
||||||
_reset_auto_parallel_context()
|
_reset_auto_parallel_context()
|
||||||
|
|
||||||
|
|
|
@ -107,4 +107,4 @@ def test_reset_auto_parallel_context():
|
||||||
assert not parameter_broadcast
|
assert not parameter_broadcast
|
||||||
assert not device_num_is_set
|
assert not device_num_is_set
|
||||||
assert not parameter_broadcast_is_set
|
assert not parameter_broadcast_is_set
|
||||||
assert not stage
|
assert stage == 1
|
||||||
|
|
Loading…
Reference in New Issue