From 2a752f24bf170b9b78fbf91d19c4f0ae8d8aed5b Mon Sep 17 00:00:00 2001
From: Ziyan
Date: Mon, 22 Feb 2021 16:01:19 +0800
Subject: [PATCH] enable not fully use opt shard

---
 mindspore/ccsrc/frontend/parallel/context.cc  |  10 ++
 mindspore/ccsrc/frontend/parallel/context.h   |   7 +
 .../parallel/ops_info/operator_info.cc        |  70 +++++++++
 .../parallel/ops_info/operator_info.h         |   1 +
 .../ccsrc/frontend/parallel/step_parallel.cc  | 141 ++++++++++--------
 .../ccsrc/frontend/parallel/step_parallel.h   |   3 +-
 .../parallel_strategy_checkpoint.cc           |  12 +-
 .../parallel_strategy_checkpoint.h            |   5 +-
 .../parallel/tensor_layout/tensor_layout.cc   |   5 +
 .../parallel/tensor_layout/tensor_layout.h    |  17 ++-
 mindspore/ccsrc/pipeline/jit/init.cc          |   8 +
 mindspore/ccsrc/utils/node_strategy.proto     |   2 +
 mindspore/core/utils/parallel_node_check.cc   |  18 ++-
 mindspore/core/utils/parallel_node_check.h    |   1 +
 mindspore/parallel/_auto_parallel_context.py  |  61 +++++++-
 mindspore/parallel/_utils.py                  |  10 +-
 mindspore/train/serialization.py              |   6 +-
 .../parallel/test_parallel_optimizer.py       |  11 ++
 .../test_set_auto_parallel_context.py         |   6 +
 19 files changed, 306 insertions(+), 88 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc
index bff42bce6e0..b82701c2112 100644
--- a/mindspore/ccsrc/frontend/parallel/context.cc
+++ b/mindspore/ccsrc/frontend/parallel/context.cc
@@ -69,6 +69,8 @@ void ParallelContext::Reset() {
   pipeline_stage_split_num_ = 1;
   grad_accumulation_step_ = 1;
   communi_parallel_mode_ = ALL_GROUP_PARALLEL;
+  optimizer_weight_shard_size_ = -1;
+  optimizer_weight_shard_integrated_save_ = false;
 }
 
 void ParallelContext::set_device_num(int64_t device_num) {
@@ -132,6 +134,14 @@ void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_sav
   group_ckpt_save_file_ = group_ckpt_save_file;
 }
 
+void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size) {
+  optimizer_weight_shard_size_ = optimizer_weight_shard_size;
+}
+
+void ParallelContext::set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save) {
+  optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
+}
+
 void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
   all_reduce_fusion_split_indices_[group] = indices;
 }

diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h
index e63d45835de..e1ce78e66d5 100644
--- a/mindspore/ccsrc/frontend/parallel/context.h
+++ b/mindspore/ccsrc/frontend/parallel/context.h
@@ -95,6 +95,11 @@ class ParallelContext {
   bool global_rank_is_set() const { return global_rank_is_set_; }
   bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }
 
+  void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);
+  int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; }
+  void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
+  bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }
+
   void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
   const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
   void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group);
@@ -152,6 +157,8 @@ class ParallelContext {
   bool enable_parallel_optimizer_;
   bool init_param_shape_;
   std::string communi_parallel_mode_;
+  int64_t optimizer_weight_shard_size_;
+  bool optimizer_weight_shard_integrated_save_;
 };
 
 }  // namespace parallel

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
index d02b1303e55..99c3c62543a 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
@@ -473,6 +473,76 @@ Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector
   return SUCCESS;
 }
 
+Status OperatorInfo::CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *groups) {
+  if (groups == nullptr) {
+    MS_LOG(ERROR) << "The group is null. Operator is " << name_;
+    return FAILED;
+  }
+  CheckGlobalDeviceManager();
+  int64_t rank = g_device_manager->global_rank();
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
+  RankList group_devices;
+  Shape tensor_map = tensor_layout->origin_tensor_map().array();
+  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
+    return FAILED;
+  }
+
+  if (group_devices.size() == 1) {
+    MS_LOG(INFO) << "The dev size is 1, no need to create group.";
+    return SUCCESS;
+  }
+  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
+  if (optimizer_weight_shard_size != -1) {
+    // not fully use opt shard
+    int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin();
+    int64_t repeated_size = group_devices.size();
+    if (repeated_size % optimizer_weight_shard_size != 0) {
+      MS_LOG(WARNING) << "Parallel optimizer: optimizer_weight_shard_size " << optimizer_weight_shard_size
+                      << " can not be applied. The repeated size of Operator " << name_ << " is " << repeated_size;
+      return FAILED;
+    }
+    repeated_size = repeated_size / optimizer_weight_shard_size;
+    // create allgather group
+    // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24]
+    RankList new_group_devices(
+      group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size,
+      group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size);
+    Group allgather_group = g_device_manager->CreateGroup(new_group_devices);
+    groups->push_back(allgather_group);
+    tensor_layout->set_opt_shard_group(allgather_group.name());
+    MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
+    // create mirror group
+    // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 16], [8, 24]
+    int64_t device_num = g_device_manager->stage_device_num();
+    Shape dev_mat = {repeated_size, device_num / repeated_size};
+    DeviceMatrix temp_dev_matrix(rank, stage_device_list_, dev_mat);
+    RankList mirror_group_devices;
+    if (temp_dev_matrix.GetDevicesAlongDim(0, &mirror_group_devices) != SUCCESS) {
+      return FAILED;
+    }
+    Group mirror_group = g_device_manager->CreateGroup(mirror_group_devices);
+    groups->push_back(mirror_group);
+    tensor_layout->set_opt_shard_mirror_group(mirror_group.name());
+    MS_LOG(INFO) << "Parallel optimizer: create mirror group " << mirror_group.name();
+  } else {
+    // fully use opt shard
+    // create allgather group
+    Group allgather_group = g_device_manager->CreateGroup(group_devices);
+    groups->push_back(allgather_group);
+    tensor_layout->set_opt_shard_group(allgather_group.name());
+    MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
+  }
+  // save in tensor_layout for strategy ckpt
+  auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_integrated_save();
+  if (!integrated_save) {
+    tensor_layout->set_opt_weight_shard_size(optimizer_weight_shard_size);
+    int32_t opt_weight_shard_step = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
+    tensor_layout->set_opt_weight_shard_step(opt_weight_shard_step);
+    MS_LOG(INFO) << "Parallel optimizer: save opt_weight_shard_step " << opt_weight_shard_step << " in strategy ckpt";
+  }
+  return SUCCESS;
+}
+
 Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
   if (group == nullptr) {
     MS_LOG(ERROR) << "The group is null.";

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
index fa6822e5d2a..36be060ac66 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
@@ -177,6 +177,7 @@ class OperatorInfo {
   void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
   int32_t stage_id() const { return stage_id_; }
   Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
+  Status CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *group);
 
   // Key for user data.
   constexpr static char key[] = "OpInfo";
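Aside (not part of the patch): a rough Python sketch of the grouping arithmetic that CreateGroupForOptShard implements above. The helper name and the example rank list are invented for illustration; the patch does this in C++ via DeviceMatrix.

    def split_opt_shard_groups(repeated_ranks, shard_size):
        # repeated_ranks: ranks holding identical copies of one weight, e.g. [0, 8, 16, 24]
        if len(repeated_ranks) % shard_size != 0:
            raise ValueError("optimizer_weight_shard_size must divide the repeated size")
        # allgather groups: consecutive blocks of shard_size ranks; each block shards one copy
        allgather = [repeated_ranks[i:i + shard_size] for i in range(0, len(repeated_ranks), shard_size)]
        # mirror groups: ranks that hold the same optimizer slice, one taken from each block
        mirror = [repeated_ranks[i::shard_size] for i in range(shard_size)]
        # opt_weight_shard_step as recorded in the strategy checkpoint
        step = (repeated_ranks[-1] - repeated_ranks[0]) // (len(repeated_ranks) - 1)
        return allgather, mirror, step

    # split_opt_shard_groups([0, 8, 16, 24], 2)
    # -> ([[0, 8], [16, 24]], [[0, 16], [8, 24]], 8)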
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 3ad85726c7f..8ee113b81be 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -39,7 +39,6 @@
 #include "frontend/parallel/graph_util/node_info.h"
 #include "frontend/parallel/node_check.h"
 #include "frontend/parallel/ops_info/matmul_info.h"
-#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "ir/param_info.h"
 #include "ir/tensor.h"
 #include "utils/comm_manager.h"
@@ -1069,7 +1068,7 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
                         << param_v.size();
     }
     auto param_ptr = param_v[0]->user_data<TensorLayout>();
-    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
+    if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
      return std::make_pair(nullptr, true);
    }
    return std::make_pair(node, true);
@@ -1077,6 +1076,14 @@
   return std::make_pair(nullptr, false);
 }
 
+static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
+  auto param_ptr = node->user_data<TensorLayout>();
+  if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
+    return std::make_pair(nullptr, false);
+  }
+  return std::make_pair(node, false);
+}
+
 // Only used for InsertMirrorOps
 std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
   if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
@@ -1084,11 +1091,7 @@
   }
 
   if (node->isa<Parameter>()) {
-    auto param_ptr = node->user_data<TensorLayout>();
-    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
-      return std::make_pair(nullptr, false);
-    }
-    return std::make_pair(node, false);
+    return FindParameterByParameter(node, func_graph);
   }
 
   if (node->isa<ValueNode>()) {
@@ -1109,8 +1112,9 @@
     if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
       return std::make_pair(node, false);
     }
-
-    if (IsParallelCareNode(cnode)) {
+    // When opt shard is not fully used, both the allgather and the mirror op are inserted.
+    // Skip the allgather here and find the parameter recursively.
+    if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
       return std::make_pair(nullptr, false);
     }
@@ -1238,10 +1242,17 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
     auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
     std::string param_name;
-    if (param_ptr != nullptr) {
+    if (param_ptr) {
       param_name = param_ptr->name();
+      std::string opt_shard_mirror_group;
+      if (param_ptr->user_data<TensorLayout>()) {
+        opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
+      }
+      if (!opt_shard_mirror_group.empty()) {
+        // the mirror op is already covered when opt shard is not fully used
+        backward_op = CreateMirrorOps(opt_shard_mirror_group, static_cast<size_t>(opt_shard_mirror_group[0]));
+      }
     }
-
     // not a RefKey
     if (!param_node_pair.second) {
       int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
@@ -1275,26 +1286,23 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
     }
     std::string instance_name = MIRROR_OP;
     CNodePtr cnode = node->input(index)->cast<CNodePtr>();
+    auto op = backward_op[0];
     if (IsCastBeforMirror(node, index) || (cnode != nullptr && IsSomePrimitive(cnode, LOAD))) {
-      for (auto &op : backward_op) {
-        // insert new node before the node
-        MS_EXCEPTION_IF_NULL(cnode);
-        AnfNodePtr pre_node = cnode->input(1);
-        InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
-        auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
-        // add fusion flag
-        AddCommOpFusionType(comm_op, param_node_pair.first);
-      }
+      // insert new node before the node
+      MS_EXCEPTION_IF_NULL(cnode);
+      AnfNodePtr pre_node = cnode->input(1);
+      InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
+      auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
+      // add fusion flag
+      AddCommOpFusionType(comm_op, param_node_pair.first);
       continue;
     }
-    for (auto &op : backward_op) {
-      AnfNodePtr pre_node = node->input(index);
-      InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
-      auto comm_op = node->input(index)->cast<CNodePtr>();
-      // add fusion flag
-      // pipeline mirror would not be set, which should be supported later
-      AddCommOpFusionType(comm_op, param_node_pair.first);
-    }
+    AnfNodePtr pre_node = node->input(index);
+    InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
+    auto comm_op = node->input(index)->cast<CNodePtr>();
+    // add fusion flag
+    // pipeline mirror would not be set, which should be supported later
+    AddCommOpFusionType(comm_op, param_node_pair.first);
   }
 }
@@ -1695,7 +1703,11 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
         manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second);
         MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
                      << GetPrimName(cnode);
-        continue;
+      } else {
+        // insert allgather operator between shard parameter and cnode
+        InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name);
+        MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
+                     << GetPrimName(cnode);
       }
     } else {
       // insert allgather operator between shard parameter and cnode
       InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name);
       MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
@@ -1708,6 +1720,35 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
   }
 }
 
+static std::string GetOptShardGroup(const AnfNodePtr &parameter, TensorLayout *const tensor_layout,
+                                    const OperatorInfoPtr &distribute_operator) {
+  std::string opt_shard_group;
+  if (!ParameterRequireGrad(parameter)) {
+    // only trainable parameters need parallel optimizer
+    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
+  } else if (parameter->cast<ParameterPtr>()->param_info() &&
+             !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
+    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
+  } else if (tensor_layout->GenerateOptShardSliceShape() == Status::SUCCESS) {
+    // get the shard tensor slice shape if the weight is repeated on devices
+    // and the shape of the first dimension could be divided
+    // apply parallel optimizer on parameters
+    // create communication group for allgather operator
+    std::vector<Group> dev_group;
+    if (distribute_operator->CreateGroupForOptShard(tensor_layout, &dev_group) == Status::SUCCESS &&
+        !dev_group.empty()) {
+      opt_shard_group = dev_group[0].name();
+      MS_LOG(INFO) << "Parallel optimizer: create group for " << parameter->ToString() << " success.";
+    } else {
+      MS_LOG(ERROR) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
+    }
+  } else {
+    MS_LOG(WARNING) << "Parallel optimizer: " << parameter->ToString() << "'s distributed shape "
+                    << tensor_layout->slice_shape().ToString() << " does not satisfy the conditions.";
+  }
+  return opt_shard_group;
+}
+
 // When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
 std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res) {
   MS_EXCEPTION_IF_NULL(parameter);
@@ -1731,33 +1772,10 @@ std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNod
   bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
   if (enable_parallel_optimizer) {
-    if (!ParameterRequireGrad(parameter)) {
-      // only trainable parameters need parallel optimizer
-      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
-    } else if (parameter->cast<ParameterPtr>()->param_info() &&
-               !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
-      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
-    } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
-      // get a totally shard tensor slice shape if the weight is repeated on devices
-      // and the shape of the first dimension could be divided
-      // apply parallel optimizer on parameters
-      // create communication group for allgather operator
-      slice_shape = tensor_layout.opt_shard_slice_shape();
-      std::vector<Group> dev_group;
-      if (distribute_operator->CreateGroupByTensorMap(tensor_layout.origin_tensor_map().array(), &dev_group) ==
-            Status::SUCCESS &&
-          !dev_group.empty()) {
-        opt_shard_group = dev_group[0].name();
-        // set communication group in tensor layout for checkpoint saving
-        tensor_layout.set_opt_shard_group(opt_shard_group);
-        MS_LOG(INFO) << "Parallel optimizer: create group " << opt_shard_group << " for " << parameter->ToString()
-                     << " success.";
-      } else {
-        MS_LOG(WARNING) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
-      }
-    } else {
-      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << "'s shape does not satisfy the conditions.";
-    }
+    opt_shard_group = GetOptShardGroup(parameter, &tensor_layout, distribute_operator);
+  }
+  if (!opt_shard_group.empty()) {
+    slice_shape = tensor_layout.opt_shard_slice_shape();
   }
MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape " << MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name(); @@ -2812,21 +2830,21 @@ bool IsCohesiveNode(const CNodePtr &cnode) { IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather); } -std::vector> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) { +ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) { if (curr_depth > MAX_RECURSIVE_DEPTH) { MS_LOG(WARNING) << "When finding the parameters' name of a operator, exceeded the maximum depth: " << MAX_RECURSIVE_DEPTH; return {}; } std::vector node_inputs{node->inputs()}; - std::vector> param_names; + ParameterMap param_names; for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) { int64_t idx = index > i ? index : i; auto input = node_inputs[i]; if (input->isa()) { auto input_parameter = input->cast(); if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) { - param_names.push_back({input_parameter->name(), idx}); + param_names.push_back({input_parameter->name(), input_parameter}); } } else if (input->isa()) { CNodePtr cnode = input->cast(); @@ -2878,10 +2896,7 @@ void CheckpointStrategy(const std::vector &all_nodes) { std::string stratey_key_name = prim->name() + "_" + param_name; stra_map[stratey_key_name] = operator_info->strategy(); for (auto param_name_pair : param_names) { - if (param_name_pair.second - 1 >= UlongToLong(input_tensor_info.size())) { - continue; - } - tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1]; + tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data(); } if (IsGatherPInfo(operator_info->name())) { auto gatherv2_info = std::dynamic_pointer_cast(operator_info); diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h index 34ae08b6c4b..5e15d6b562f 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -32,6 +32,7 @@ #include "pipeline/jit/pipeline.h" #include "frontend/parallel/ops_info/ops_utils.h" #include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" using OperatorInfoPtr = std::shared_ptr; @@ -139,7 +140,7 @@ bool IsLastStage(); void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, const FuncGraphManagerPtr &manager); -std::vector> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth); +ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth); void CheckpointStrategy(const std::vector &all_nodes); diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc index 64e8614a0b8..929c8774e64 100644 --- a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc +++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc @@ -17,7 +17,6 @@ #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" #include -#include #include #include "utils/ms_utils.h" @@ -141,20 +140,19 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf } } for (auto &node_tensor_info : tensor_info_map) { - TensorInfo 
-    TensorLayout tensor_layout = tensor_info.tensor_layout();
+    TensorLayoutPtr tensor_layout = node_tensor_info.second;
     straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
     MS_EXCEPTION_IF_NULL(parallel_layout_item);
     parallel_layout_item->set_param_name(node_tensor_info.first);
     straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
     straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
     MS_EXCEPTION_IF_NULL(dev_matrix);
-    for (auto dim : tensor_layout.device_arrangement().array()) {
+    for (auto dim : tensor_layout->device_arrangement().array()) {
       dev_matrix->add_dim(LongToUlong(dim));
     }
     straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
     MS_EXCEPTION_IF_NULL(tensor_map);
-    for (auto dim : tensor_layout.tensor_map().array()) {
+    for (auto dim : tensor_layout->tensor_map().array()) {
       tensor_map->add_dim(dim);
     }
     straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
@@ -165,7 +163,9 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
       param_split_shape->add_dim(dim_pair.first);
       indices_offset->add_dim(dim_pair.second);
     }
-    parallel_layouts->set_field(tensor_layout.get_field_size());
+    parallel_layouts->set_field(tensor_layout->get_field_size());
+    parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
+    parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
   }
 
   std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);

diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
index 5048f21ab92..6c0be17bbc5 100644
--- a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
+++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include <memory>
 #include "frontend/parallel/ops_info/ops_utils.h"
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/context.h"
@@ -30,7 +31,9 @@
 namespace mindspore {
 namespace parallel {
 using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
-using TensorInfoMap = std::unordered_map<std::string, TensorInfo>;
+using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
+using TensorInfoMap = std::unordered_map<std::string, TensorLayoutPtr>;
+using ParameterMap = std::vector<std::pair<std::string, ParameterPtr>>;
 using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int64_t, int64_t>>>;
 using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
 class StrategyCheckpoint {

diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc
index 3543a7520bb..cc03be776be 100644
--- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc
@@ -21,6 +21,7 @@
 #include "ir/value.h"
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/status.h"
+#include "frontend/parallel/context.h"
 #include "frontend/parallel/tensor_layout/shape_util.h"
 #include "utils/log_adapter.h"
 
@@ -431,6 +432,10 @@ Status TensorLayout::GenerateOptShardSliceShape() {
   int64_t repeated_num =
     std::accumulate(repeated_dev.begin(), repeated_dev.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
   int64_t split_num;
+  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
+  if (optimizer_weight_shard_size != -1) {
+    repeated_num = optimizer_weight_shard_size;
+  }
   if (tensor_map[0] == MAP_NONE) {
     split_num = repeated_num;
   } else {

diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h
index 1d04ff74f99..0bef175c6cb 100644
--- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h
@@ -104,6 +104,18 @@ class TensorLayout {
 
   std::string opt_shard_group() { return opt_shard_group_; }
 
+  void set_opt_shard_mirror_group(std::string name) { opt_shard_mirror_group_ = std::move(name); }
+
+  std::string opt_shard_mirror_group() { return opt_shard_mirror_group_; }
+
+  void set_opt_weight_shard_step(int32_t step) { opt_weight_shard_step_ = step; }
+
+  int32_t opt_weight_shard_step() { return opt_weight_shard_step_; }
+
+  void set_opt_weight_shard_size(int32_t size) { opt_weight_shard_size_ = size; }
+
+  int32_t opt_weight_shard_size() { return opt_weight_shard_size_; }
+
   // Key for user data.
   constexpr static char key[] = "TLayout";
 
@@ -129,7 +141,10 @@ class TensorLayout {
   bool layout_transfer_ = false;
   int32_t field_size_ = 0;
   Shape opt_shard_slice_shape_;
-  std::string opt_shard_group_ = "";
+  std::string opt_shard_group_ = "";         // for allgather
+  std::string opt_shard_mirror_group_ = "";  // for mirror ops
+  int32_t opt_weight_shard_step_ = 0;
+  int32_t opt_weight_shard_size_ = 0;
 };
 }  // namespace parallel
 }  // namespace mindspore

diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc
index a6f62c74811..dea07931f3e 100644
--- a/mindspore/ccsrc/pipeline/jit/init.cc
+++ b/mindspore/ccsrc/pipeline/jit/init.cc
@@ -173,6 +173,14 @@ PYBIND11_MODULE(_c_expression, m) {
          "Get enable/disable parallel optimizer.")
     .def("set_communi_parallel_mode", &ParallelContext::set_communi_parallel_mode, "Set communication parallel mode.")
     .def("get_communi_parallel_mode", &ParallelContext::communi_parallel_mode, "Get communication parallel mode.")
+    .def("set_optimizer_weight_shard_size", &ParallelContext::set_optimizer_weight_shard_size,
+         "Set the optimizer weight shard group size when the parallel optimizer is not fully used.")
+    .def("get_optimizer_weight_shard_size", &ParallelContext::optimizer_weight_shard_size,
+         "Get the optimizer weight shard group size when the parallel optimizer is not fully used.")
+    .def("set_optimizer_weight_shard_integrated_save", &ParallelContext::set_optimizer_weight_shard_integrated_save,
+         "Set whether to save the integrated weight when the parallel optimizer is enabled.")
+    .def("get_optimizer_weight_shard_integrated_save", &ParallelContext::optimizer_weight_shard_integrated_save,
+         "Get whether to save the integrated weight when the parallel optimizer is enabled.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
 
   (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")

diff --git a/mindspore/ccsrc/utils/node_strategy.proto b/mindspore/ccsrc/utils/node_strategy.proto
index ffb8f4d87dd..a2b45e125eb 100644
--- a/mindspore/ccsrc/utils/node_strategy.proto
+++ b/mindspore/ccsrc/utils/node_strategy.proto
@@ -54,6 +54,8 @@ message ParallelLayouts {
   repeated ParamSplitShape param_split_shape = 3;
   repeated IndicesOffset indices_offset = 4;
   required int32 field = 5;
+  required int32 opt_weight_shard_step = 6;
+  required int32 opt_weight_shard_size = 7;
 }
 
 message ParallelLayoutItem {

diff --git a/mindspore/core/utils/parallel_node_check.cc b/mindspore/core/utils/parallel_node_check.cc
index 855ac58e784..c974752f8ae 100644
--- a/mindspore/core/utils/parallel_node_check.cc
+++ b/mindspore/core/utils/parallel_node_check.cc
@@ -14,11 +14,10 @@
  * limitations under the License.
  */
 
-#include "utils/parallel_node_check.h"
-
 #include
 #include
 
+#include "utils/parallel_node_check.h"
 #include "base/core_ops.h"
 
 namespace mindspore {
@@ -32,6 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
     "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve",
     "BroadcastGradientArgs", "InvertPermutation", "DropoutGenMask", "embed", "create_instance",
     "RefToEmbed", "stop_gradient", "Send", "UpdateState", "Load"};
+static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather};
 // clang-format on
 
 bool IsInParallelBlackList(const PrimitivePtr &prim) {
@@ -39,6 +39,15 @@
   return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end());
 }
 
+bool IsInAllGatherNodeList(const CNodePtr &cnode) {
+  for (auto &value : ALLGATHER_NODE_LIST_) {
+    if (IsPrimitiveCNode(cnode, value)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool IsParallelConsiderCNode(const CNodePtr &cnode) {
   if (cnode == nullptr || cnode->size() == 0) {
     return false;
   }
@@ -51,9 +60,6 @@
   if (prim == nullptr) {
     return false;
   }
-  if (IsInParallelBlackList(prim)) {
-    return false;
-  }
-  return true;
+  return !IsInParallelBlackList(prim);
 }
 }  // namespace mindspore

diff --git a/mindspore/core/utils/parallel_node_check.h b/mindspore/core/utils/parallel_node_check.h
index 491827d15c0..ea73da2b8a8 100644
--- a/mindspore/core/utils/parallel_node_check.h
+++ b/mindspore/core/utils/parallel_node_check.h
@@ -21,6 +21,7 @@
 namespace mindspore {
 bool IsInParallelBlackList(const PrimitivePtr &);
+bool IsInAllGatherNodeList(const CNodePtr &);
 bool IsParallelConsiderCNode(const CNodePtr &);
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_

diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py
index 06fdf0bb94d..334c7b4ce90 100644
--- a/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/parallel/_auto_parallel_context.py
@@ -15,6 +15,7 @@
 """Context of auto parallel"""
 import threading
 import mindspore.context as context
+import mindspore.log as logger
 from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
 from mindspore.parallel._ps_context import _is_role_pserver
 from mindspore._c_expression import AutoParallelContext
@@ -501,6 +502,48 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_communi_parallel_mode()
 
+    def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size):
+        """
+        Set optimizer_weight_shard_size.
+
+        Args:
+            optimizer_weight_shard_size (int): Size of the optimizer weight shard group when the parallel
+                optimizer is not applied across all data-parallel devices.
+        """
+        self.check_context_handle()
+        if not isinstance(optimizer_weight_shard_size, int):
+            raise TypeError('optimizer_weight_shard_size must be an int')
+        if optimizer_weight_shard_size <= 1:
+            logger.warning("The setting 'optimizer_weight_shard_size' is invalid. "
" + "Please use the integer larger than 1.") + return + self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size) + + def get_optimizer_weight_shard_size(self): + """Get optimizer_weight_shard_size.""" + self.check_context_handle() + return self._context_handle.get_optimizer_weight_shard_size() + + def set_optimizer_weight_shard_integrated_save(self, optimizer_weight_shard_integrated_save): + """ + Set optimizer_weight_shard_integrated_save. + + Args: + optimizer_weight_shard_integrated_save (bool): Whether to integrated save weight shard when + enable parallel optimizer. + """ + self.check_context_handle() + if not isinstance(optimizer_weight_shard_integrated_save, bool): + raise TypeError('optimizer_weight_shard_integrated_save is invalid type') + self._context_handle.set_optimizer_weight_shard_integrated_save(optimizer_weight_shard_integrated_save) + + + def get_optimizer_weight_shard_integrated_save(self): + """Get optimizer_weight_shard_size.""" + self.check_context_handle() + return self._context_handle.get_optimizer_weight_shard_integrated_save() + + def reset(self): """Reset all settings.""" self.check_context_handle() @@ -540,7 +583,9 @@ _set_auto_parallel_context_func_map = { "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer, "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step, "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices, - "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode} + "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode, + "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size, + "optimizer_weight_shard_integrated_save": auto_parallel_context().set_optimizer_weight_shard_integrated_save} _get_auto_parallel_context_func_map = { @@ -559,7 +604,9 @@ _get_auto_parallel_context_func_map = { "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer, "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step, "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices, - "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode} + "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode, + "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size, + "optimizer_weight_shard_integrated_save": auto_parallel_context().get_optimizer_weight_shard_integrated_save} @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, @@ -567,7 +614,8 @@ _get_auto_parallel_context_func_map = { parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str, - communi_parallel_mode=str) + communi_parallel_mode=str, optimizer_weight_shard_size=int, + optimizer_weight_shard_integrated_save=bool) def _set_auto_parallel_context(**kwargs): """ @@ -615,7 +663,7 @@ def _set_auto_parallel_context(**kwargs): pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how the devices are distributed alone the pipeline. The total devices will be divided into 'pipeline_stags' stages. This currently could only be used when - parall mode semi_auto_parallel is enabled. Default: 0 + parallel mode semi_auto_parallel is enabled. 
         communi_parallel_mode (str): There are three kinds of communication parallel modes, "all_group_parallel",
                         "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".
 
                         - all_group_parallel: All communication groups are in a parallel manner.
                         - same_server_group_parallel: Only the communication groups within the same server are parallel.
                         - no_group_parallel: All communication groups are not parallel.
+        optimizer_weight_shard_size (int): Set the optimizer weight shard group size when the parallel optimizer
+                        is not fully used. It should be larger than one and less than or equal to the data parallel
+                        size. Default: -1, which means the parallel optimizer is fully used along the data parallel
+                        dimension.
+        optimizer_weight_shard_integrated_save (bool): Whether to save the integrated weight when the parallel
+                        optimizer is enabled. Default: False.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
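Aside (not part of the patch): a minimal usage sketch of the two options wired up above. The parallel mode, device_num and the shard size are placeholder values, and optimizer_weight_shard_size must divide the data-parallel repeat count of each weight.

    from mindspore import context

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32,
                                      enable_parallel_optimizer=True,
                                      # shard each weight over groups of 4 devices instead of all repeated devices
                                      optimizer_weight_shard_size=4,
                                      # keep per-rank shards in checkpoints instead of gathering them on save
                                      optimizer_weight_shard_integrated_save=False)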
diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py
index 088ed6fd388..2c596a0ca03 100644
--- a/mindspore/parallel/_utils.py
+++ b/mindspore/parallel/_utils.py
@@ -248,6 +248,10 @@ def _remove_repeated_slices(tensor_layout):
 def _infer_rank_list(train_map, predict_map=None):
     """infer checkpoint slices to be loaded"""
     ret = {}
+    if _get_pipeline_stages() > 1:
+        local_rank = int(_get_global_rank() % (_get_device_num() / _get_pipeline_stages()))
+    else:
+        local_rank = _get_global_rank()
     for param_name in train_map:
         train_layout = train_map[param_name]
         train_dev_mat = train_layout[0]
@@ -271,15 +275,13 @@ def _infer_rank_list(train_map, predict_map=None):
         dev_num = np.array(predict_layout[0]).prod()
         # optimization pass
         if _check_same_layout(train_layout, predict_layout):
-            dev_rank = _get_global_rank()
-            ret[param_name] = ([dev_rank], True)
+            ret[param_name] = ([local_rank], True)
             continue
         if _check_similar_layout(train_layout, predict_layout):
             if len(rank_list) == 1:
                 ret[param_name] = (rank_list, True)
             elif len(rank_list) == dev_num:
-                dev_rank = _get_global_rank()
-                ret[param_name] = ([rank_list[dev_rank]], True)
+                ret[param_name] = ([rank_list[local_rank]], True)
             else:
                 ret[param_name] = (rank_list, False)
         else:

diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index d069277335b..fa6b6dcc566 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -597,7 +597,7 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
             allgather_net = get_allgather_cell(opt_shard_group, False)
             net.parallel_parameter_merge_net_dict[param_name] = allgather_net
         param_data = allgather_net(param_data)
-    elif opt_shard_group:
+    elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_integrated_save"):
         if allgather_net is None:
             allgather_net = get_allgather_cell(opt_shard_group, False)
             net.parallel_parameter_merge_net_dict[param_name] = allgather_net
         param_data = allgather_net(param_data)
@@ -1247,7 +1249,9 @@ def _convert_to_list(strategy):
             tensor_map = list(layout.tensor_map[0].dim)
             param_split_shape = list(layout.param_split_shape[0].dim)
             field_size = int(layout.field)
-            train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size]
+            shard_stride = int(layout.opt_weight_shard_step)
+            shard_size = int(layout.opt_weight_shard_size)
+            train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size, shard_stride, shard_size]
         except BaseException as e:
             raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")
     return train_map

diff --git a/tests/ut/python/parallel/test_parallel_optimizer.py b/tests/ut/python/parallel/test_parallel_optimizer.py
index de0d156f420..9f8ce85b096 100644
--- a/tests/ut/python/parallel/test_parallel_optimizer.py
+++ b/tests/ut/python/parallel/test_parallel_optimizer.py
@@ -131,6 +131,17 @@ def test_auto_parallel_momentum_5():
     assert not param_dict["weight2"][5]
 
 
+def test_auto_parallel_momentum_6():
+    # test not fully using the parallel optimizer with optimizer_weight_shard_size
+    # weight1 could not be sharded and weight2 is repeated
+    context.set_auto_parallel_context(optimizer_weight_shard_size=2)
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    param_dict = train_network.parameter_layout_dict
+    # validate opt_shard_group
+    assert param_dict["weight1"][5].startswith("2")
+    assert param_dict["weight2"][5].startswith("2")
+
+
 def test_AdamWeightDecay():
     """ test_AdamWeightDecay """
     context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)

diff --git a/tests/ut/python/parallel/test_set_auto_parallel_context.py b/tests/ut/python/parallel/test_set_auto_parallel_context.py
index 5f6401718fa..e1e05cdfdc9 100644
--- a/tests/ut/python/parallel/test_set_auto_parallel_context.py
+++ b/tests/ut/python/parallel/test_set_auto_parallel_context.py
@@ -59,6 +59,10 @@ def test_set_auto_parallel_context():
     parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
     assert parameter_broadcast_is_set
 
+    auto_parallel_context().set_optimizer_weight_shard_integrated_save(True)
+    integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()
+    assert integrated_save
+
     with pytest.raises(ValueError):
         context.set_auto_parallel_context(device_num=0)
 
@@ -105,6 +109,7 @@ def test_reset_auto_parallel_context():
     parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
     stage = auto_parallel_context().get_pipeline_stages()
     communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
+    integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()
 
     assert device_num == 1
     assert global_rank == 0
@@ -116,3 +121,4 @@ def test_reset_auto_parallel_context():
     assert not parameter_broadcast_is_set
     assert stage == 1
     assert communi_parallel_mode == "all_group_parallel"
+    assert not integrated_save
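Aside (not part of the patch): a reading aid for the serialization change above — a sketch of the per-parameter layout list that _convert_to_list now returns. The concrete numbers are invented for illustration.

    # [dev_mat, tensor_map, param_split_shape, field_size, opt_weight_shard_step, opt_weight_shard_size]
    layout = [[4, 8], [1, -1], [], 0, 8, 2]
    dev_mat, tensor_map, param_split_shape, field_size, shard_stride, shard_size = layout
    # shard_stride == 8: neighbouring optimizer shards of this weight sit 8 ranks apart
    # shard_size == 2: the full weight is gathered from a group of 2 ranks at load time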