forked from mindspore-Ecosystem/mindspore
enable not fully use opt shard
parent ac5af72836
commit 2a752f24bf
@@ -69,6 +69,8 @@ void ParallelContext::Reset() {
  pipeline_stage_split_num_ = 1;
  grad_accumulation_step_ = 1;
  communi_parallel_mode_ = ALL_GROUP_PARALLEL;
  optimizer_weight_shard_size_ = -1;
  optimizer_weight_shard_integrated_save_ = false;
}

void ParallelContext::set_device_num(int64_t device_num) {

@@ -132,6 +134,14 @@ void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_sav
  group_ckpt_save_file_ = group_ckpt_save_file;
}

void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size) {
  optimizer_weight_shard_size_ = optimizer_weight_shard_size;
}

void ParallelContext::set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save) {
  optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
}

void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
  all_reduce_fusion_split_indices_[group] = indices;
}
@@ -95,6 +95,11 @@ class ParallelContext {
  bool global_rank_is_set() const { return global_rank_is_set_; }
  bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }

  void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);
  int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; }
  void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
  bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }

  void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
  const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
  void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group);

@@ -152,6 +157,8 @@ class ParallelContext {
  bool enable_parallel_optimizer_;
  bool init_param_shape_;
  std::string communi_parallel_mode_;
  int64_t optimizer_weight_shard_size_;
  bool optimizer_weight_shard_integrated_save_;
};

}  // namespace parallel
@@ -473,6 +473,76 @@ Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector
  return SUCCESS;
}

Status OperatorInfo::CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *groups) {
  if (groups == nullptr) {
    MS_LOG(ERROR) << "The group is null. Operator is " << name_;
    return FAILED;
  }
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
  RankList group_devices;
  Shape tensor_map = tensor_layout->origin_tensor_map().array();
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    return FAILED;
  }

  if (group_devices.size() == 1) {
    MS_LOG(INFO) << "The dev size is 1, no need to create group.";
    return SUCCESS;
  }
  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
  if (optimizer_weight_shard_size != -1) {
    // not fully use opt shard
    int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin();
    int64_t repeated_size = group_devices.size();
    if (repeated_size % optimizer_weight_shard_size != 0) {
      MS_LOG(WARNING) << "Parallel optimizer: optimizer_weight_shard_size " << optimizer_weight_shard_size
                      << " can not be applied. The repeated size of Operator " << name_ << " is " << repeated_size;
      return FAILED;
    }
    repeated_size = repeated_size / optimizer_weight_shard_size;
    // create allgather group
    // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24]
    RankList new_group_devices(
      group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size,
      group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size);
    Group allgather_group = g_device_manager->CreateGroup(new_group_devices);
    groups->push_back(allgather_group);
    tensor_layout->set_opt_shard_group(allgather_group.name());
    MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
    // create mirror group
    // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 16], [8, 24]
    int64_t device_num = g_device_manager->stage_device_num();
    Shape dev_mat = {repeated_size, device_num / repeated_size};
    DeviceMatrix temp_dev_matrix(rank, stage_device_list_, dev_mat);
    RankList mirror_group_devices;
    if (temp_dev_matrix.GetDevicesAlongDim(0, &mirror_group_devices) != SUCCESS) {
      return FAILED;
    }
    Group mirror_group = g_device_manager->CreateGroup(mirror_group_devices);
    groups->push_back(mirror_group);
    tensor_layout->set_opt_shard_mirror_group(mirror_group.name());
    MS_LOG(INFO) << "Parallel optimizer: create mirror group " << mirror_group.name();
  } else {
    // fully use opt shard
    // create allgather group
    Group allgather_group = g_device_manager->CreateGroup(group_devices);
    groups->push_back(allgather_group);
    tensor_layout->set_opt_shard_group(allgather_group.name());
    MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
  }
  // save in tensor_layout for strategy ckpt
  auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_integrated_save();
  if (!integrated_save) {
    tensor_layout->set_opt_weight_shard_size(optimizer_weight_shard_size);
    int32_t opt_weight_shard_step = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
    tensor_layout->set_opt_weight_shard_step(opt_weight_shard_step);
    MS_LOG(INFO) << "Parallel optimizer: save opt_weight_shard_step " << opt_weight_shard_step << " in strategy ckpt";
  }
  return SUCCESS;
}

Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
  if (group == nullptr) {
    MS_LOG(ERROR) << "The group is null.";
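To make the grouping arithmetic in CreateGroupForOptShard above easier to follow, here is a small stand-alone Python sketch of the same partitioning (an illustration only, not MindSpore code; the real implementation derives the mirror groups from the device matrix). It reproduces the example in the comments: ranks [0, 8, 16, 24] with optimizer_weight_shard_size = 2.

# Illustration of the group partitioning performed by CreateGroupForOptShard above
# (plain Python; the helper name and rank list are taken from the example comments).
def split_opt_shard_groups(group_devices, shard_size):
    """Split repeated ranks into allgather groups of `shard_size` and mirror groups across them."""
    assert len(group_devices) % shard_size == 0
    # allgather groups: consecutive chunks of `shard_size` ranks, e.g. [0, 8], [16, 24]
    allgather = [group_devices[i:i + shard_size] for i in range(0, len(group_devices), shard_size)]
    # mirror groups: the k-th member of every allgather group, e.g. [0, 16], [8, 24]
    mirror = [group_devices[k::shard_size] for k in range(shard_size)]
    # opt_weight_shard_step recorded for the strategy checkpoint: stride between adjacent ranks
    step = (group_devices[-1] - group_devices[0]) // (len(group_devices) - 1)
    return allgather, mirror, step

print(split_opt_shard_groups([0, 8, 16, 24], shard_size=2))
# ([[0, 8], [16, 24]], [[0, 16], [8, 24]], 8)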
@@ -177,6 +177,7 @@ class OperatorInfo {
  void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
  int32_t stage_id() const { return stage_id_; }
  Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
  Status CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *group);

  // Key for user data.
  constexpr static char key[] = "OpInfo";
@@ -39,7 +39,6 @@
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/node_check.h"
#include "frontend/parallel/ops_info/matmul_info.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/comm_manager.h"

@@ -1069,7 +1068,7 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
                  << param_v.size();
  }
  auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
  if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
  if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
    return std::make_pair(nullptr, true);
  }
  return std::make_pair(node, true);

@@ -1077,6 +1076,14 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
  return std::make_pair(nullptr, false);
}

static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  auto param_ptr = node->user_data<parallel::TensorLayout>();
  if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
    return std::make_pair(nullptr, false);
  }
  return std::make_pair(node, false);
}

// Only used for InsertMirrorOps
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {

@@ -1084,11 +1091,7 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
  }

  if (node->isa<Parameter>()) {
    auto param_ptr = node->user_data<parallel::TensorLayout>();
    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
      return std::make_pair(nullptr, false);
    }
    return std::make_pair(node, false);
    return FindParameterByParameter(node, func_graph);
  }

  if (node->isa<ValueNode>()) {

@@ -1109,8 +1112,9 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
    if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
      return std::make_pair(node, false);
    }

    if (IsParallelCareNode(cnode)) {
    // When not fully use opt shard, allgather and mirror would be both inserted.
    // Skip allgather here and find parameter recursively.
    if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
      return std::make_pair(nullptr, false);
    }

@@ -1238,10 +1242,17 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons

    auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
    std::string param_name;
    if (param_ptr != nullptr) {
    if (param_ptr) {
      param_name = param_ptr->name();
      std::string opt_shard_mirror_group;
      if (param_ptr->user_data<TensorLayout>()) {
        opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
      }
      if (!opt_shard_mirror_group.empty()) {
        // mirror ops is covered in not fully use opt shard case
        backward_op = CreateMirrorOps(opt_shard_mirror_group, static_cast<size_t>(opt_shard_mirror_group[0]));
      }
    }

    // not a RefKey
    if (!param_node_pair.second) {
      int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();

@@ -1275,8 +1286,8 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
    }
    std::string instance_name = MIRROR_OP;
    CNodePtr cnode = node->input(index)->cast<CNodePtr>();
    auto op = backward_op[0];
    if (IsCastBeforMirror(node, index) || (cnode != nullptr && IsSomePrimitive(cnode, LOAD))) {
      for (auto &op : backward_op) {
        // insert new node before the node
        MS_EXCEPTION_IF_NULL(cnode);
        AnfNodePtr pre_node = cnode->input(1);

@@ -1284,10 +1295,8 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
        auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
        // add fusion flag
        AddCommOpFusionType(comm_op, param_node_pair.first);
      }
      continue;
    }
    for (auto &op : backward_op) {
      AnfNodePtr pre_node = node->input(index);
      InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
      auto comm_op = node->input(index)->cast<CNodePtr>();

@@ -1295,7 +1304,6 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
      // pipeline mirror would not be set, which should be supported later
      AddCommOpFusionType(comm_op, param_node_pair.first);
    }
  }
}

void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator, const CNodePtr &node,

@@ -1695,7 +1703,11 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
        manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second);
        MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
                     << GetPrimName(cnode);
        continue;
      } else {
        // insert allgather operator between shard parameter and cnode
        InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name);
        MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
                     << GetPrimName(cnode);
      }
    } else {
      // insert allgather operator between shard parameter and cnode

@@ -1708,6 +1720,35 @@
  }
}

static std::string GetOptShardGroup(const AnfNodePtr &parameter, TensorLayout *const tensor_layout,
                                    const OperatorInfoPtr &distribute_operator) {
  std::string opt_shard_group;
  if (!ParameterRequireGrad(parameter)) {
    // only trainable parameters need parallel optimizer
    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
  } else if (parameter->cast<ParameterPtr>()->param_info() &&
             !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
  } else if (tensor_layout->GenerateOptShardSliceShape() == Status::SUCCESS) {
    // get the shard tensor slice shape if the weight is repeated on devices
    // and the shape of the first dimension could be divided
    // apply parallel optimizer on parameters
    // create communication group for allgather operator
    std::vector<Group> dev_group;
    if (distribute_operator->CreateGroupForOptShard(tensor_layout, &dev_group) == Status::SUCCESS &&
        !dev_group.empty()) {
      opt_shard_group = dev_group[0].name();
      MS_LOG(INFO) << "Parallel optimizer: create group for " << parameter->ToString() << " success.";
    } else {
      MS_LOG(ERROR) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
    }
  } else {
    MS_LOG(WARNING) << "Parallel optimizer: " << parameter->ToString() << "'s distributed shape "
                    << tensor_layout->slice_shape().ToString() << " does not satisfy the conditions.";
  }
  return opt_shard_group;
}

// When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res) {
  MS_EXCEPTION_IF_NULL(parameter);

@@ -1731,33 +1772,10 @@ std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNod
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
  if (enable_parallel_optimizer) {
    if (!ParameterRequireGrad(parameter)) {
      // only trainable parameters need parallel optimizer
      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
    } else if (parameter->cast<ParameterPtr>()->param_info() &&
               !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
    } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
      // get a totally shard tensor slice shape if the weight is repeated on devices
      // and the shape of the first dimension could be divided
      // apply parallel optimizer on parameters
      // create communication group for allgather operator
    opt_shard_group = GetOptShardGroup(parameter, &tensor_layout, distribute_operator);
    }
    if (!opt_shard_group.empty()) {
      slice_shape = tensor_layout.opt_shard_slice_shape();
      std::vector<Group> dev_group;
      if (distribute_operator->CreateGroupByTensorMap(tensor_layout.origin_tensor_map().array(), &dev_group) ==
            Status::SUCCESS &&
          !dev_group.empty()) {
        opt_shard_group = dev_group[0].name();
        // set communication group in tensor layout for checkpoint saving
        tensor_layout.set_opt_shard_group(opt_shard_group);
        MS_LOG(INFO) << "Parallel optimizer: create group " << opt_shard_group << " for " << parameter->ToString()
                     << " success.";
      } else {
        MS_LOG(WARNING) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
      }
    } else {
      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << "'s shape does not satisfy the conditions.";
    }
  }
  MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
               << MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name();

@@ -2812,21 +2830,21 @@ bool IsCohesiveNode(const CNodePtr &cnode) {
         IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather);
}

std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
  if (curr_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(WARNING) << "When finding the parameters' name of a operator, exceeded the maximum depth: "
                    << MAX_RECURSIVE_DEPTH;
    return {};
  }
  std::vector<AnfNodePtr> node_inputs{node->inputs()};
  std::vector<std::pair<std::string, int64_t>> param_names;
  ParameterMap param_names;
  for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
    int64_t idx = index > i ? index : i;
    auto input = node_inputs[i];
    if (input->isa<Parameter>()) {
      auto input_parameter = input->cast<ParameterPtr>();
      if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
        param_names.push_back({input_parameter->name(), idx});
        param_names.push_back({input_parameter->name(), input_parameter});
      }
    } else if (input->isa<CNode>()) {
      CNodePtr cnode = input->cast<CNodePtr>();

@@ -2878,10 +2896,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
      std::string stratey_key_name = prim->name() + "_" + param_name;
      stra_map[stratey_key_name] = operator_info->strategy();
      for (auto param_name_pair : param_names) {
        if (param_name_pair.second - 1 >= UlongToLong(input_tensor_info.size())) {
          continue;
        }
        tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1];
        tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data<TensorLayout>();
      }
      if (IsGatherPInfo(operator_info->name())) {
        auto gatherv2_info = std::dynamic_pointer_cast<GatherPInfo>(operator_info);
@@ -32,6 +32,7 @@
#include "pipeline/jit/pipeline.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"

using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;

@@ -139,7 +140,7 @@ bool IsLastStage();
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                           const FuncGraphManagerPtr &manager);

std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);

void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);
@@ -17,7 +17,6 @@
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"

#include <fstream>
#include <memory>
#include <vector>

#include "utils/ms_utils.h"

@@ -141,20 +140,19 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
    }
  }
  for (auto &node_tensor_info : tensor_info_map) {
    TensorInfo tensor_info = node_tensor_info.second;
    TensorLayout tensor_layout = tensor_info.tensor_layout();
    TensorLayoutPtr tensor_layout = node_tensor_info.second;
    straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
    MS_EXCEPTION_IF_NULL(parallel_layout_item);
    parallel_layout_item->set_param_name(node_tensor_info.first);
    straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
    straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
    MS_EXCEPTION_IF_NULL(dev_matrix);
    for (auto dim : tensor_layout.device_arrangement().array()) {
    for (auto dim : tensor_layout->device_arrangement().array()) {
      dev_matrix->add_dim(LongToUlong(dim));
    }
    straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
    MS_EXCEPTION_IF_NULL(tensor_map);
    for (auto dim : tensor_layout.tensor_map().array()) {
    for (auto dim : tensor_layout->tensor_map().array()) {
      tensor_map->add_dim(dim);
    }
    straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();

@@ -165,7 +163,9 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
      param_split_shape->add_dim(dim_pair.first);
      indices_offset->add_dim(dim_pair.second);
    }
    parallel_layouts->set_field(tensor_layout.get_field_size());
    parallel_layouts->set_field(tensor_layout->get_field_size());
    parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
    parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
  }

  std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
@@ -21,6 +21,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include <memory>
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/context.h"

@@ -30,7 +31,9 @@
namespace mindspore {
namespace parallel {
using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
using TensorInfoMap = std::unordered_map<std::string, TensorInfo>;
using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
using TensorInfoMap = std::unordered_map<std::string, TensorLayoutPtr>;
using ParameterMap = std::vector<std::pair<std::string, ParameterPtr>>;
using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int64_t, int64_t>>>;
using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
class StrategyCheckpoint {
@@ -21,6 +21,7 @@
#include "ir/value.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/tensor_layout/shape_util.h"
#include "utils/log_adapter.h"

@@ -431,6 +432,10 @@ Status TensorLayout::GenerateOptShardSliceShape() {
  int64_t repeated_num =
    std::accumulate(repeated_dev.begin(), repeated_dev.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
  int64_t split_num;
  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
  if (optimizer_weight_shard_size != -1) {
    repeated_num = optimizer_weight_shard_size;
  }
  if (tensor_map[0] == MAP_NONE) {
    split_num = repeated_num;
  } else {
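As a rough illustration of what this override changes, the sketch below divides a weight's first dimension by the chosen split count; the helper and the shapes are invented for illustration, and the real computation lives in GenerateOptShardSliceShape.

# Rough sketch of the effect of optimizer_weight_shard_size on the opt shard slice
# (hypothetical helper and shapes; not the actual MindSpore implementation).
def opt_shard_slice_shape(weight_shape, repeated_num, optimizer_weight_shard_size=-1):
    split_num = repeated_num if optimizer_weight_shard_size == -1 else optimizer_weight_shard_size
    if weight_shape[0] % split_num != 0:
        return None  # first dimension cannot be divided, no opt shard slice
    return [weight_shape[0] // split_num] + list(weight_shape[1:])

print(opt_shard_slice_shape([32, 32], repeated_num=4))                                 # [8, 32], fully sharded
print(opt_shard_slice_shape([32, 32], repeated_num=4, optimizer_weight_shard_size=2))  # [16, 32], partial shard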
@@ -104,6 +104,18 @@ class TensorLayout {

  std::string opt_shard_group() { return opt_shard_group_; }

  void set_opt_shard_mirror_group(std::string name) { opt_shard_mirror_group_ = std::move(name); }

  std::string opt_shard_mirror_group() { return opt_shard_mirror_group_; }

  void set_opt_weight_shard_step(int32_t step) { opt_weight_shard_step_ = step; }

  int32_t opt_weight_shard_step() { return opt_weight_shard_step_; }

  void set_opt_weight_shard_size(int32_t size) { opt_weight_shard_size_ = size; }

  int32_t opt_weight_shard_size() { return opt_weight_shard_size_; }

  // Key for user data.
  constexpr static char key[] = "TLayout";

@@ -129,7 +141,10 @@ class TensorLayout {
  bool layout_transfer_ = false;
  int32_t field_size_ = 0;
  Shape opt_shard_slice_shape_;
  std::string opt_shard_group_ = "";
  std::string opt_shard_group_ = "";         // for allgather
  std::string opt_shard_mirror_group_ = "";  // for mirror ops
  int32_t opt_weight_shard_step_ = 0;
  int32_t opt_weight_shard_size_ = 0;
};
}  // namespace parallel
}  // namespace mindspore
@@ -173,6 +173,14 @@ PYBIND11_MODULE(_c_expression, m) {
         "Get enable/disable parallel optimizer.")
    .def("set_communi_parallel_mode", &ParallelContext::set_communi_parallel_mode, "Set communication parallel mode.")
    .def("get_communi_parallel_mode", &ParallelContext::communi_parallel_mode, "Get communication parallel mode.")
    .def("set_optimizer_weight_shard_size", &ParallelContext::set_optimizer_weight_shard_size,
         "Set opt shard group size when not fully use parallel optimizer.")
    .def("get_optimizer_weight_shard_size", &ParallelContext::optimizer_weight_shard_size,
         "Get opt shard group size when not fully use parallel optimizer.")
    .def("set_optimizer_weight_shard_integrated_save", &ParallelContext::set_optimizer_weight_shard_integrated_save,
         "Set whether to integrated save weight shard when enable parallel optimizer.")
    .def("get_optimizer_weight_shard_integrated_save", &ParallelContext::optimizer_weight_shard_integrated_save,
         "Get whether to integrated save weight shard when enable parallel optimizer.")
    .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");

  (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
@@ -54,6 +54,8 @@ message ParallelLayouts {
  repeated ParamSplitShape param_split_shape = 3;
  repeated IndicesOffset indices_offset = 4;
  required int32 field = 5;
  required int32 opt_weight_shard_step = 6;
  required int32 opt_weight_shard_size = 7;
}

message ParallelLayoutItem {
@@ -14,11 +14,10 @@
 * limitations under the License.
 */

#include "utils/parallel_node_check.h"

#include <set>
#include <string>

#include "utils/parallel_node_check.h"
#include "base/core_ops.h"

namespace mindspore {

@@ -32,6 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
  "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
  "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
  "stop_gradient", "Send", "UpdateState", "Load"};
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather};
// clang-format on

bool IsInParallelBlackList(const PrimitivePtr &prim) {

@@ -39,6 +39,15 @@ bool IsInParallelBlackList(const PrimitivePtr &prim) {
  return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end());
}

bool IsInAllGatherNodeList(const CNodePtr &cnode) {
  for (auto &value : ALLGATHER_NODE_LIST_) {
    if (IsPrimitiveCNode(cnode, value)) {
      return true;
    }
  }
  return false;
}

bool IsParallelConsiderCNode(const CNodePtr &cnode) {
  if (cnode == nullptr || cnode->size() == 0) {
    return false;

@@ -51,9 +60,6 @@ bool IsParallelConsiderCNode(const CNodePtr &cnode) {
  if (prim == nullptr) {
    return false;
  }
  if (IsInParallelBlackList(prim)) {
    return false;
  }
  return true;
  return !IsInParallelBlackList(prim);
}
}  // namespace mindspore
@@ -21,6 +21,7 @@

namespace mindspore {
bool IsInParallelBlackList(const PrimitivePtr &);
bool IsInAllGatherNodeList(const CNodePtr &);
bool IsParallelConsiderCNode(const CNodePtr &);
}  // namespace mindspore
#endif  // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_
@@ -15,6 +15,7 @@
"""Context of auto parallel"""
import threading
import mindspore.context as context
import mindspore.log as logger
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore.parallel._ps_context import _is_role_pserver
from mindspore._c_expression import AutoParallelContext

@@ -501,6 +502,48 @@ class _AutoParallelContext:
        self.check_context_handle()
        return self._context_handle.get_communi_parallel_mode()

    def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size):
        """
        Set optimizer_weight_shard_size.

        Args:
            optimizer_weight_shard_size (int): Opt shard group size when not globally use parallel
                optimizer across devices.
        """
        self.check_context_handle()
        if not isinstance(optimizer_weight_shard_size, int):
            raise TypeError('optimizer_weight_shard_size is invalid type')
        if optimizer_weight_shard_size <= 1:
            logger.warning("The setting 'optimizer_weight_shard_size' is invalid. "
                           "Please use the integer larger than 1.")
            return
        self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size)

    def get_optimizer_weight_shard_size(self):
        """Get optimizer_weight_shard_size."""
        self.check_context_handle()
        return self._context_handle.get_optimizer_weight_shard_size()

    def set_optimizer_weight_shard_integrated_save(self, optimizer_weight_shard_integrated_save):
        """
        Set optimizer_weight_shard_integrated_save.

        Args:
            optimizer_weight_shard_integrated_save (bool): Whether to integrated save weight shard when
                enable parallel optimizer.
        """
        self.check_context_handle()
        if not isinstance(optimizer_weight_shard_integrated_save, bool):
            raise TypeError('optimizer_weight_shard_integrated_save is invalid type')
        self._context_handle.set_optimizer_weight_shard_integrated_save(optimizer_weight_shard_integrated_save)


    def get_optimizer_weight_shard_integrated_save(self):
        """Get optimizer_weight_shard_size."""
        self.check_context_handle()
        return self._context_handle.get_optimizer_weight_shard_integrated_save()


    def reset(self):
        """Reset all settings."""
        self.check_context_handle()

@@ -540,7 +583,9 @@ _set_auto_parallel_context_func_map = {
    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
    "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
    "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
    "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode}
    "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
    "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
    "optimizer_weight_shard_integrated_save": auto_parallel_context().set_optimizer_weight_shard_integrated_save}


_get_auto_parallel_context_func_map = {

@@ -559,7 +604,9 @@ _get_auto_parallel_context_func_map = {
    "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
    "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
    "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
    "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode}
    "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
    "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
    "optimizer_weight_shard_integrated_save": auto_parallel_context().get_optimizer_weight_shard_integrated_save}


@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,

@@ -567,7 +614,8 @@ _get_auto_parallel_context_func_map = {
                 parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
                 communi_parallel_mode=str)
                 communi_parallel_mode=str, optimizer_weight_shard_size=int,
                 optimizer_weight_shard_integrated_save=bool)

def _set_auto_parallel_context(**kwargs):
    """

@@ -615,7 +663,7 @@ def _set_auto_parallel_context(**kwargs):
        pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
            the devices are distributed alone the pipeline. The total devices will be divided into
            'pipeline_stags' stages. This currently could only be used when
            parall mode semi_auto_parallel is enabled. Default: 0
            parallel mode semi_auto_parallel is enabled. Default: 0
        communi_parallel_mode (str): There are tree kinds of communication parallel modes, "all_group_parallel",
            "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".

@@ -624,6 +672,11 @@ def _set_auto_parallel_context(**kwargs):
            - same_server_group_parallel: Only the communication groups within the same server are parallel.

            - no_group_parallel: All communication groups are not parallel.
        optimizer_weight_shard_size (int): Set optimizer shard group size when not fully use parallel optimizer.
            It should be larger than one and less than or equal with the data parallel size.
            Default: -1, which means fully use parallel optimizer in data parallel dimension.
        optimizer_weight_shard_integrated_save (bool): Whether to integrated save weight shard when enable parallel
            optimizer. Default: False.

    Raises:
        ValueError: If input key is not attribute in auto parallel context.
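For reference, a usage sketch of the two new options through the public mindspore.context API; the device number and parallel mode below are only illustrative values.

# Usage sketch of the new auto parallel context options (illustrative values).
from mindspore import context

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32,
                                  enable_parallel_optimizer=True,
                                  optimizer_weight_shard_size=2,
                                  optimizer_weight_shard_integrated_save=False)
print(context.get_auto_parallel_context("optimizer_weight_shard_size"))             # 2
print(context.get_auto_parallel_context("optimizer_weight_shard_integrated_save"))  # False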
@@ -248,6 +248,10 @@ def _remove_repeated_slices(tensor_layout):
def _infer_rank_list(train_map, predict_map=None):
    """infer checkpoint slices to be loaded"""
    ret = {}
    if _get_pipeline_stages() > 1:
        local_rank = int(_get_global_rank() % (_get_device_num() / _get_pipeline_stages()))
    else:
        local_rank = _get_global_rank()
    for param_name in train_map:
        train_layout = train_map[param_name]
        train_dev_mat = train_layout[0]

@@ -271,15 +275,13 @@ def _infer_rank_list(train_map, predict_map=None):
        dev_num = np.array(predict_layout[0]).prod()
        # optimization pass
        if _check_same_layout(train_layout, predict_layout):
            dev_rank = _get_global_rank()
            ret[param_name] = ([dev_rank], True)
            ret[param_name] = ([local_rank], True)
            continue
        if _check_similar_layout(train_layout, predict_layout):
            if len(rank_list) == 1:
                ret[param_name] = (rank_list, True)
            elif len(rank_list) == dev_num:
                dev_rank = _get_global_rank()
                ret[param_name] = ([rank_list[dev_rank]], True)
                ret[param_name] = ([rank_list[local_rank]], True)
            else:
                ret[param_name] = (rank_list, False)
        else:
@@ -597,7 +597,7 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
            allgather_net = get_allgather_cell(opt_shard_group, False)
            net.parallel_parameter_merge_net_dict[param_name] = allgather_net
        param_data = allgather_net(param_data)
    elif opt_shard_group:
    elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_integrated_save"):
        if allgather_net is None:
            allgather_net = get_allgather_cell(opt_shard_group, False)
            net.parallel_parameter_merge_net_dict[param_name] = allgather_net

@@ -1247,7 +1247,9 @@ def _convert_to_list(strategy):
            tensor_map = list(layout.tensor_map[0].dim)
            param_split_shape = list(layout.param_split_shape[0].dim)
            field_size = int(layout.field)
            train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size]
            shard_stride = int(layout.opt_weight_shard_step)
            shard_size = int(layout.opt_weight_shard_size)
            train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size, shard_stride, shard_size]
        except BaseException as e:
            raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")
    return train_map
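To visualize the change to the strategy layout list, a hypothetical train_map entry after this commit would carry the two extra fields; the parameter name and values below are invented for illustration.

# Hypothetical _convert_to_list output entry (made-up values):
# [dev_mat, tensor_map, param_split_shape, field_size, shard_stride, shard_size]
train_map = {"model.weight1": [[4, 8], [1, 0], [32, 32], 0, 8, 2]}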
@@ -131,6 +131,17 @@ def test_auto_parallel_momentum_5():
    assert not param_dict["weight2"][5]


def test_auto_parallel_momentum_6():
    # test not fully use parallel optimizer with optimizer_weight_shard_size
    # weight1 could not be shard and weight2 is repeated
    context.set_auto_parallel_context(optimizer_weight_shard_size=2)
    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
    param_dict = train_network.parameter_layout_dict
    # validate opt_shard_group
    assert param_dict["weight1"][5].startswith("2")
    assert param_dict["weight2"][5].startswith("2")


def test_AdamWeightDecay():
    """ test_AdamWeightDecay """
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
@@ -59,6 +59,10 @@ def test_set_auto_parallel_context():
    parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
    assert parameter_broadcast_is_set

    auto_parallel_context().set_optimizer_weight_shard_integrated_save(True)
    integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()
    assert integrated_save

    with pytest.raises(ValueError):
        context.set_auto_parallel_context(device_num=0)

@@ -105,6 +109,7 @@ def test_reset_auto_parallel_context():
    parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
    stage = auto_parallel_context().get_pipeline_stages()
    communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
    integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()

    assert device_num == 1
    assert global_rank == 0

@@ -116,3 +121,4 @@ def test_reset_auto_parallel_context():
    assert not parameter_broadcast_is_set
    assert stage == 1
    assert communi_parallel_mode == "all_group_parallel"
    assert not integrated_save