Enable partial (not fully used) optimizer weight sharding (opt shard)

Ziyan 2021-02-22 16:01:19 +08:00
parent ac5af72836
commit 2a752f24bf
19 changed files with 306 additions and 88 deletions

View File

@ -69,6 +69,8 @@ void ParallelContext::Reset() {
pipeline_stage_split_num_ = 1;
grad_accumulation_step_ = 1;
communi_parallel_mode_ = ALL_GROUP_PARALLEL;
optimizer_weight_shard_size_ = -1;
optimizer_weight_shard_integrated_save_ = false;
}
void ParallelContext::set_device_num(int64_t device_num) {
@ -132,6 +134,14 @@ void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_sav
group_ckpt_save_file_ = group_ckpt_save_file;
}
void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size) {
optimizer_weight_shard_size_ = optimizer_weight_shard_size;
}
void ParallelContext::set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save) {
optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
}
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
all_reduce_fusion_split_indices_[group] = indices;
}

View File

@ -95,6 +95,11 @@ class ParallelContext {
bool global_rank_is_set() const { return global_rank_is_set_; }
bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }
void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);
int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; }
void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }
void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group);
@ -152,6 +157,8 @@ class ParallelContext {
bool enable_parallel_optimizer_;
bool init_param_shape_;
std::string communi_parallel_mode_;
int64_t optimizer_weight_shard_size_;
bool optimizer_weight_shard_integrated_save_;
};
} // namespace parallel

View File

@ -473,6 +473,76 @@ Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector
return SUCCESS;
}
Status OperatorInfo::CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *groups) {
if (groups == nullptr) {
MS_LOG(ERROR) << "The group is null. Operator is " << name_;
return FAILED;
}
CheckGlobalDeviceManager();
int64_t rank = g_device_manager->global_rank();
DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
RankList group_devices;
Shape tensor_map = tensor_layout->origin_tensor_map().array();
if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
return FAILED;
}
if (group_devices.size() == 1) {
MS_LOG(INFO) << "The dev size is 1, no need to create group.";
return SUCCESS;
}
int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
if (optimizer_weight_shard_size != -1) {
// not fully use opt shard
int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin();
int64_t repeated_size = group_devices.size();
if (repeated_size % optimizer_weight_shard_size != 0) {
MS_LOG(WARNING) << "Parallel optimizer: optimizer_weight_shard_size " << optimizer_weight_shard_size
<< " can not be applied. The repeated size of Operator " << name_ << " is " << repeated_size;
return FAILED;
}
repeated_size = repeated_size / optimizer_weight_shard_size;
// create allgather group
// eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24]
RankList new_group_devices(
group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size,
group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size);
Group allgather_group = g_device_manager->CreateGroup(new_group_devices);
groups->push_back(allgather_group);
tensor_layout->set_opt_shard_group(allgather_group.name());
MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
// create mirror group
// eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 16], [8, 24]
int64_t device_num = g_device_manager->stage_device_num();
Shape dev_mat = {repeated_size, device_num / repeated_size};
DeviceMatrix temp_dev_matrix(rank, stage_device_list_, dev_mat);
RankList mirror_group_devices;
if (temp_dev_matrix.GetDevicesAlongDim(0, &mirror_group_devices) != SUCCESS) {
return FAILED;
}
Group mirror_group = g_device_manager->CreateGroup(mirror_group_devices);
groups->push_back(mirror_group);
tensor_layout->set_opt_shard_mirror_group(mirror_group.name());
MS_LOG(INFO) << "Parallel optimizer: create mirror group " << mirror_group.name();
} else {
// fully use opt shard
// create allgather group
Group allgather_group = g_device_manager->CreateGroup(group_devices);
groups->push_back(allgather_group);
tensor_layout->set_opt_shard_group(allgather_group.name());
MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
}
// save in tensor_layout for strategy ckpt
auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_integrated_save();
if (!integrated_save) {
tensor_layout->set_opt_weight_shard_size(optimizer_weight_shard_size);
int32_t opt_weight_shard_step = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
tensor_layout->set_opt_weight_shard_step(opt_weight_shard_step);
MS_LOG(INFO) << "Parallel optimizer: save opt_weight_shard_step " << opt_weight_shard_step << " in strategy ckpt";
}
return SUCCESS;
}
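
For reference, a minimal standalone Python sketch (not part of this commit; helper name invented) of the grouping arithmetic used by CreateGroupForOptShard when opt shard is not fully used. It assumes the repeated devices are evenly spaced, as in the [0, 8, 16, 24] example from the comments above; the real code derives the mirror group through a DeviceMatrix instead.

def split_opt_shard_groups(group_devices, rank, shard_size):
    """Return (allgather_group, mirror_group) for one rank.

    group_devices: ranks on which the parameter is repeated, e.g. [0, 8, 16, 24]
    shard_size:    optimizer_weight_shard_size (> 1, must divide len(group_devices))
    """
    index = group_devices.index(rank)
    if len(group_devices) % shard_size != 0:
        raise ValueError("optimizer_weight_shard_size cannot be applied")
    # allgather group: the consecutive chunk of size shard_size containing `rank`
    # e.g. shard_size = 2: [0, 8, 16, 24] -> [0, 8] and [16, 24]
    start = index // shard_size * shard_size
    allgather_group = group_devices[start:start + shard_size]
    # mirror group: the ranks holding the same slice in every chunk
    # e.g. shard_size = 2: [0, 8, 16, 24] -> [0, 16] and [8, 24]
    mirror_group = group_devices[index % shard_size::shard_size]
    return allgather_group, mirror_group

for r in [0, 8, 16, 24]:
    print(r, split_opt_shard_groups([0, 8, 16, 24], r, 2))
# 0 -> ([0, 8], [0, 16])     8 -> ([0, 8], [8, 24])
# 16 -> ([16, 24], [0, 16])  24 -> ([16, 24], [8, 24])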
Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
if (group == nullptr) {
MS_LOG(ERROR) << "The group is null.";

View File

@ -177,6 +177,7 @@ class OperatorInfo {
void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
int32_t stage_id() const { return stage_id_; }
Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
Status CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *group);
// Key for user data.
constexpr static char key[] = "OpInfo";

View File

@ -39,7 +39,6 @@
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/node_check.h"
#include "frontend/parallel/ops_info/matmul_info.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/comm_manager.h"
@ -1069,7 +1068,7 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
<< param_v.size();
}
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
return std::make_pair(nullptr, true);
}
return std::make_pair(node, true);
@ -1077,6 +1076,14 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
return std::make_pair(nullptr, false);
}
static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
auto param_ptr = node->user_data<parallel::TensorLayout>();
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
return std::make_pair(nullptr, false);
}
return std::make_pair(node, false);
}
// Only used for InsertMirrorOps
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
@ -1084,11 +1091,7 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
}
if (node->isa<Parameter>()) {
auto param_ptr = node->user_data<parallel::TensorLayout>();
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
return std::make_pair(nullptr, false);
}
return std::make_pair(node, false);
return FindParameterByParameter(node, func_graph);
}
if (node->isa<ValueNode>()) {
@ -1109,8 +1112,9 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
return std::make_pair(node, false);
}
if (IsParallelCareNode(cnode)) {
// When opt shard is not fully used, both an allgather and a mirror op are inserted.
// Skip the allgather here and look for the parameter recursively.
if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
return std::make_pair(nullptr, false);
}
@ -1238,10 +1242,17 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
std::string param_name;
if (param_ptr != nullptr) {
if (param_ptr) {
param_name = param_ptr->name();
std::string opt_shard_mirror_group;
if (param_ptr->user_data<TensorLayout>()) {
opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
}
if (!opt_shard_mirror_group.empty()) {
// when opt shard is not fully used, rebuild the mirror op over the opt-shard mirror group
backward_op = CreateMirrorOps(opt_shard_mirror_group, static_cast<size_t>(opt_shard_mirror_group[0]));
}
}
// not a RefKey
if (!param_node_pair.second) {
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
@ -1275,26 +1286,23 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
}
std::string instance_name = MIRROR_OP;
CNodePtr cnode = node->input(index)->cast<CNodePtr>();
auto op = backward_op[0];
if (IsCastBeforMirror(node, index) || (cnode != nullptr && IsSomePrimitive(cnode, LOAD))) {
for (auto &op : backward_op) {
// insert new node before the node
MS_EXCEPTION_IF_NULL(cnode);
AnfNodePtr pre_node = cnode->input(1);
InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
// add fusion flag
AddCommOpFusionType(comm_op, param_node_pair.first);
}
// insert new node before the node
MS_EXCEPTION_IF_NULL(cnode);
AnfNodePtr pre_node = cnode->input(1);
InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
// add fusion flag
AddCommOpFusionType(comm_op, param_node_pair.first);
continue;
}
for (auto &op : backward_op) {
AnfNodePtr pre_node = node->input(index);
InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
auto comm_op = node->input(index)->cast<CNodePtr>();
// add fusion flag
// pipeline mirror would not be set, which should be supported later
AddCommOpFusionType(comm_op, param_node_pair.first);
}
AnfNodePtr pre_node = node->input(index);
InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
auto comm_op = node->input(index)->cast<CNodePtr>();
// add fusion flag
// pipeline mirror would not be set, which should be supported later
AddCommOpFusionType(comm_op, param_node_pair.first);
}
}
@ -1695,7 +1703,11 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second);
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
<< GetPrimName(cnode);
continue;
} else {
// insert allgather operator between shard parameter and cnode
InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name);
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
<< GetPrimName(cnode);
}
} else {
// insert allgather operator between shard parameter and cnode
@ -1708,6 +1720,35 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
}
}
static std::string GetOptShardGroup(const AnfNodePtr &parameter, TensorLayout *const tensor_layout,
const OperatorInfoPtr &distribute_operator) {
std::string opt_shard_group;
if (!ParameterRequireGrad(parameter)) {
// only trainable parameters need parallel optimizer
MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
} else if (parameter->cast<ParameterPtr>()->param_info() &&
!parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
} else if (tensor_layout->GenerateOptShardSliceShape() == Status::SUCCESS) {
// get the shard tensor slice shape if the weight is repeated on devices
// and the shape of the first dimension could be divided
// apply parallel optimizer on parameters
// create communication group for allgather operator
std::vector<Group> dev_group;
if (distribute_operator->CreateGroupForOptShard(tensor_layout, &dev_group) == Status::SUCCESS &&
!dev_group.empty()) {
opt_shard_group = dev_group[0].name();
MS_LOG(INFO) << "Parallel optimizer: create group for " << parameter->ToString() << " success.";
} else {
MS_LOG(ERROR) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
}
} else {
MS_LOG(WARNING) << "Parallel optimizer: " << parameter->ToString() << "'s distributed shape "
<< tensor_layout->slice_shape().ToString() << " does not satisfy the conditions.";
}
return opt_shard_group;
}
// When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res) {
MS_EXCEPTION_IF_NULL(parameter);
@ -1731,33 +1772,10 @@ std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNod
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
if (enable_parallel_optimizer) {
if (!ParameterRequireGrad(parameter)) {
// only trainable parameters need parallel optimizer
MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
} else if (parameter->cast<ParameterPtr>()->param_info() &&
!parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
} else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
// get a totally shard tensor slice shape if the weight is repeated on devices
// and the shape of the first dimension could be divided
// apply parallel optimizer on parameters
// create communication group for allgather operator
slice_shape = tensor_layout.opt_shard_slice_shape();
std::vector<Group> dev_group;
if (distribute_operator->CreateGroupByTensorMap(tensor_layout.origin_tensor_map().array(), &dev_group) ==
Status::SUCCESS &&
!dev_group.empty()) {
opt_shard_group = dev_group[0].name();
// set communication group in tensor layout for checkpoint saving
tensor_layout.set_opt_shard_group(opt_shard_group);
MS_LOG(INFO) << "Parallel optimizer: create group " << opt_shard_group << " for " << parameter->ToString()
<< " success.";
} else {
MS_LOG(WARNING) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
}
} else {
MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << "'s shape does not satisfy the conditions.";
}
opt_shard_group = GetOptShardGroup(parameter, &tensor_layout, distribute_operator);
}
if (!opt_shard_group.empty()) {
slice_shape = tensor_layout.opt_shard_slice_shape();
}
MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
<< MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name();
@ -2812,21 +2830,21 @@ bool IsCohesiveNode(const CNodePtr &cnode) {
IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather);
}
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
if (curr_depth > MAX_RECURSIVE_DEPTH) {
MS_LOG(WARNING) << "When finding the parameters' name of an operator, exceeded the maximum depth: "
<< MAX_RECURSIVE_DEPTH;
return {};
}
std::vector<AnfNodePtr> node_inputs{node->inputs()};
std::vector<std::pair<std::string, int64_t>> param_names;
ParameterMap param_names;
for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
int64_t idx = index > i ? index : i;
auto input = node_inputs[i];
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
param_names.push_back({input_parameter->name(), idx});
param_names.push_back({input_parameter->name(), input_parameter});
}
} else if (input->isa<CNode>()) {
CNodePtr cnode = input->cast<CNodePtr>();
@ -2878,10 +2896,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
std::string stratey_key_name = prim->name() + "_" + param_name;
stra_map[stratey_key_name] = operator_info->strategy();
for (auto param_name_pair : param_names) {
if (param_name_pair.second - 1 >= UlongToLong(input_tensor_info.size())) {
continue;
}
tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1];
tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data<TensorLayout>();
}
if (IsGatherPInfo(operator_info->name())) {
auto gatherv2_info = std::dynamic_pointer_cast<GatherPInfo>(operator_info);

View File

@ -32,6 +32,7 @@
#include "pipeline/jit/pipeline.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
@ -139,7 +140,7 @@ bool IsLastStage();
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager);
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);

View File

@ -17,7 +17,6 @@
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include <fstream>
#include <memory>
#include <vector>
#include "utils/ms_utils.h"
@ -141,20 +140,19 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
}
}
for (auto &node_tensor_info : tensor_info_map) {
TensorInfo tensor_info = node_tensor_info.second;
TensorLayout tensor_layout = tensor_info.tensor_layout();
TensorLayoutPtr tensor_layout = node_tensor_info.second;
straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
MS_EXCEPTION_IF_NULL(parallel_layout_item);
parallel_layout_item->set_param_name(node_tensor_info.first);
straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
MS_EXCEPTION_IF_NULL(dev_matrix);
for (auto dim : tensor_layout.device_arrangement().array()) {
for (auto dim : tensor_layout->device_arrangement().array()) {
dev_matrix->add_dim(LongToUlong(dim));
}
straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
MS_EXCEPTION_IF_NULL(tensor_map);
for (auto dim : tensor_layout.tensor_map().array()) {
for (auto dim : tensor_layout->tensor_map().array()) {
tensor_map->add_dim(dim);
}
straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
@ -165,7 +163,9 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
param_split_shape->add_dim(dim_pair.first);
indices_offset->add_dim(dim_pair.second);
}
parallel_layouts->set_field(tensor_layout.get_field_size());
parallel_layouts->set_field(tensor_layout->get_field_size());
parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
}
std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);

View File

@ -21,6 +21,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include <memory>
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/context.h"
@ -30,7 +31,9 @@
namespace mindspore {
namespace parallel {
using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
using TensorInfoMap = std::unordered_map<std::string, TensorInfo>;
using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
using TensorInfoMap = std::unordered_map<std::string, TensorLayoutPtr>;
using ParameterMap = std::vector<std::pair<std::string, ParameterPtr>>;
using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int64_t, int64_t>>>;
using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
class StrategyCheckpoint {

View File

@ -21,6 +21,7 @@
#include "ir/value.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/tensor_layout/shape_util.h"
#include "utils/log_adapter.h"
@ -431,6 +432,10 @@ Status TensorLayout::GenerateOptShardSliceShape() {
int64_t repeated_num =
std::accumulate(repeated_dev.begin(), repeated_dev.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
int64_t split_num;
int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
if (optimizer_weight_shard_size != -1) {
repeated_num = optimizer_weight_shard_size;
}
if (tensor_map[0] == MAP_NONE) {
split_num = repeated_num;
} else {

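For reference, a minimal sketch (invented names, not part of the commit) of the override added above: when optimizer_weight_shard_size is set, it replaces the repetition count that GenerateOptShardSliceShape would otherwise use to split the first dimension of the weight.

from functools import reduce
import operator

def opt_shard_split_base(repeated_dev, optimizer_weight_shard_size=-1):
    # repeated_dev: device-matrix dims on which the weight is repeated
    repeated_num = reduce(operator.mul, repeated_dev, 1)
    if optimizer_weight_shard_size != -1:
        # not fully use opt shard: split only optimizer_weight_shard_size ways
        repeated_num = optimizer_weight_shard_size
    return repeated_num

print(opt_shard_split_base([2, 4]))      # 8  (fully use opt shard)
print(opt_shard_split_base([2, 4], 2))   # 2  (not fully use opt shard)
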
View File

@ -104,6 +104,18 @@ class TensorLayout {
std::string opt_shard_group() { return opt_shard_group_; }
void set_opt_shard_mirror_group(std::string name) { opt_shard_mirror_group_ = std::move(name); }
std::string opt_shard_mirror_group() { return opt_shard_mirror_group_; }
void set_opt_weight_shard_step(int32_t step) { opt_weight_shard_step_ = step; }
int32_t opt_weight_shard_step() { return opt_weight_shard_step_; }
void set_opt_weight_shard_size(int32_t size) { opt_weight_shard_size_ = size; }
int32_t opt_weight_shard_size() { return opt_weight_shard_size_; }
// Key for user data.
constexpr static char key[] = "TLayout";
@ -129,7 +141,10 @@ class TensorLayout {
bool layout_transfer_ = false;
int32_t field_size_ = 0;
Shape opt_shard_slice_shape_;
std::string opt_shard_group_ = "";
std::string opt_shard_group_ = ""; // for allgather
std::string opt_shard_mirror_group_ = ""; // for mirror ops
int32_t opt_weight_shard_step_ = 0;
int32_t opt_weight_shard_size_ = 0;
};
} // namespace parallel
} // namespace mindspore

View File

@ -173,6 +173,14 @@ PYBIND11_MODULE(_c_expression, m) {
"Get enable/disable parallel optimizer.")
.def("set_communi_parallel_mode", &ParallelContext::set_communi_parallel_mode, "Set communication parallel mode.")
.def("get_communi_parallel_mode", &ParallelContext::communi_parallel_mode, "Get communication parallel mode.")
.def("set_optimizer_weight_shard_size", &ParallelContext::set_optimizer_weight_shard_size,
"Set opt shard group size when not fully use parallel optimizer.")
.def("get_optimizer_weight_shard_size", &ParallelContext::optimizer_weight_shard_size,
"Get opt shard group size when not fully use parallel optimizer.")
.def("set_optimizer_weight_shard_integrated_save", &ParallelContext::set_optimizer_weight_shard_integrated_save,
"Set whether to integrated save weight shard when enable parallel optimizer.")
.def("get_optimizer_weight_shard_integrated_save", &ParallelContext::optimizer_weight_shard_integrated_save,
"Get whether to integrated save weight shard when enable parallel optimizer.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")

View File

@ -54,6 +54,8 @@ message ParallelLayouts {
repeated ParamSplitShape param_split_shape = 3;
repeated IndicesOffset indices_offset = 4;
required int32 field = 5;
required int32 opt_weight_shard_step = 6;
required int32 opt_weight_shard_size = 7;
}
message ParallelLayoutItem {

View File

@ -14,11 +14,10 @@
* limitations under the License.
*/
#include "utils/parallel_node_check.h"
#include <set>
#include <string>
#include "utils/parallel_node_check.h"
#include "base/core_ops.h"
namespace mindspore {
@ -32,6 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
"stop_gradient", "Send", "UpdateState", "Load"};
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather};
// clang-format on
bool IsInParallelBlackList(const PrimitivePtr &prim) {
@ -39,6 +39,15 @@ bool IsInParallelBlackList(const PrimitivePtr &prim) {
return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end());
}
bool IsInAllGatherNodeList(const CNodePtr &cnode) {
for (auto &value : ALLGATHER_NODE_LIST_) {
if (IsPrimitiveCNode(cnode, value)) {
return true;
}
}
return false;
}
bool IsParallelConsiderCNode(const CNodePtr &cnode) {
if (cnode == nullptr || cnode->size() == 0) {
return false;
@ -51,9 +60,6 @@ bool IsParallelConsiderCNode(const CNodePtr &cnode) {
if (prim == nullptr) {
return false;
}
if (IsInParallelBlackList(prim)) {
return false;
}
return true;
return !IsInParallelBlackList(prim);
}
} // namespace mindspore

View File

@ -21,6 +21,7 @@
namespace mindspore {
bool IsInParallelBlackList(const PrimitivePtr &);
bool IsInAllGatherNodeList(const CNodePtr &);
bool IsParallelConsiderCNode(const CNodePtr &);
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_

View File

@ -15,6 +15,7 @@
"""Context of auto parallel"""
import threading
import mindspore.context as context
import mindspore.log as logger
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore.parallel._ps_context import _is_role_pserver
from mindspore._c_expression import AutoParallelContext
@ -501,6 +502,48 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_communi_parallel_mode()
def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size):
"""
Set optimizer_weight_shard_size.
Args:
optimizer_weight_shard_size (int): The opt shard group size when the parallel optimizer is
not fully used across devices.
"""
self.check_context_handle()
if not isinstance(optimizer_weight_shard_size, int):
raise TypeError('optimizer_weight_shard_size is invalid type')
if optimizer_weight_shard_size <= 1:
logger.warning("The setting 'optimizer_weight_shard_size' is invalid. "
"Please use the integer larger than 1.")
return
self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size)
def get_optimizer_weight_shard_size(self):
"""Get optimizer_weight_shard_size."""
self.check_context_handle()
return self._context_handle.get_optimizer_weight_shard_size()
def set_optimizer_weight_shard_integrated_save(self, optimizer_weight_shard_integrated_save):
"""
Set optimizer_weight_shard_integrated_save.
Args:
optimizer_weight_shard_integrated_save (bool): Whether to save the integrated weight shard
when the parallel optimizer is enabled.
"""
self.check_context_handle()
if not isinstance(optimizer_weight_shard_integrated_save, bool):
raise TypeError('optimizer_weight_shard_integrated_save is invalid type')
self._context_handle.set_optimizer_weight_shard_integrated_save(optimizer_weight_shard_integrated_save)
def get_optimizer_weight_shard_integrated_save(self):
"""Get optimizer_weight_shard_size."""
self.check_context_handle()
return self._context_handle.get_optimizer_weight_shard_integrated_save()
def reset(self):
"""Reset all settings."""
self.check_context_handle()
@ -540,7 +583,9 @@ _set_auto_parallel_context_func_map = {
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
"grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
"all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
"communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode}
"communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
"optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
"optimizer_weight_shard_integrated_save": auto_parallel_context().set_optimizer_weight_shard_integrated_save}
_get_auto_parallel_context_func_map = {
@ -559,7 +604,9 @@ _get_auto_parallel_context_func_map = {
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
"grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
"all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
"communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode}
"communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
"optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
"optimizer_weight_shard_integrated_save": auto_parallel_context().get_optimizer_weight_shard_integrated_save}
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
@ -567,7 +614,8 @@ _get_auto_parallel_context_func_map = {
parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
communi_parallel_mode=str)
communi_parallel_mode=str, optimizer_weight_shard_size=int,
optimizer_weight_shard_integrated_save=bool)
def _set_auto_parallel_context(**kwargs):
"""
@ -615,7 +663,7 @@ def _set_auto_parallel_context(**kwargs):
pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
the devices are distributed alone the pipeline. The total devices will be divided into
'pipeline_stags' stages. This currently could only be used when
parall mode semi_auto_parallel is enabled. Default: 0
parallel mode semi_auto_parallel is enabled. Default: 0
communi_parallel_mode (str): There are three kinds of communication parallel modes, "all_group_parallel",
"same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".
@ -624,6 +672,11 @@ def _set_auto_parallel_context(**kwargs):
- same_server_group_parallel: Only the communication groups within the same server are parallel.
- no_group_parallel: All communication groups are not parallel.
optimizer_weight_shard_size (int): Set the optimizer shard group size when the parallel optimizer
is not fully used. It should be larger than one and less than or equal to the data parallel size.
Default: -1, which means the parallel optimizer is fully used across the data parallel dimension.
optimizer_weight_shard_integrated_save (bool): Whether to save the integrated weight shard when the
parallel optimizer is enabled. Default: False.
Raises:
ValueError: If input key is not attribute in auto parallel context.
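
For reference, a usage sketch of the two new options through the public context API (modeled on the unit test added in this commit; the surrounding network and training setup are omitted, and the other keyword arguments are only illustrative):

from mindspore import context

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32,
                                  enable_parallel_optimizer=True,
                                  optimizer_weight_shard_size=2,
                                  optimizer_weight_shard_integrated_save=False)

assert context.get_auto_parallel_context("optimizer_weight_shard_size") == 2
assert not context.get_auto_parallel_context("optimizer_weight_shard_integrated_save")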

View File

@ -248,6 +248,10 @@ def _remove_repeated_slices(tensor_layout):
def _infer_rank_list(train_map, predict_map=None):
"""infer checkpoint slices to be loaded"""
ret = {}
if _get_pipeline_stages() > 1:
local_rank = int(_get_global_rank() % (_get_device_num() / _get_pipeline_stages()))
else:
local_rank = _get_global_rank()
for param_name in train_map:
train_layout = train_map[param_name]
train_dev_mat = train_layout[0]
@ -271,15 +275,13 @@ def _infer_rank_list(train_map, predict_map=None):
dev_num = np.array(predict_layout[0]).prod()
# optimization pass
if _check_same_layout(train_layout, predict_layout):
dev_rank = _get_global_rank()
ret[param_name] = ([dev_rank], True)
ret[param_name] = ([local_rank], True)
continue
if _check_similar_layout(train_layout, predict_layout):
if len(rank_list) == 1:
ret[param_name] = (rank_list, True)
elif len(rank_list) == dev_num:
dev_rank = _get_global_rank()
ret[param_name] = ([rank_list[dev_rank]], True)
ret[param_name] = ([rank_list[local_rank]], True)
else:
ret[param_name] = (rank_list, False)
else:

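A worked example of the local_rank computation added above (standalone sketch, not MindSpore code): with pipeline parallelism each stage holds device_num / pipeline_stages devices, so the global rank is mapped back into the rank space of its own stage before indexing the inferred rank lists.

def local_rank(global_rank, device_num, pipeline_stages):
    if pipeline_stages > 1:
        return int(global_rank % (device_num / pipeline_stages))
    return global_rank

# 32 devices, 4 pipeline stages -> 8 ranks per stage
print(local_rank(19, 32, 4))   # 3  (rank 19 is rank 3 inside its stage)
print(local_rank(19, 32, 1))   # 19 (no pipeline parallelism, unchanged)
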
View File

@ -597,7 +597,7 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
allgather_net = get_allgather_cell(opt_shard_group, False)
net.parallel_parameter_merge_net_dict[param_name] = allgather_net
param_data = allgather_net(param_data)
elif opt_shard_group:
elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_integrated_save"):
if allgather_net is None:
allgather_net = get_allgather_cell(opt_shard_group, False)
net.parallel_parameter_merge_net_dict[param_name] = allgather_net
@ -1247,7 +1247,9 @@ def _convert_to_list(strategy):
tensor_map = list(layout.tensor_map[0].dim)
param_split_shape = list(layout.param_split_shape[0].dim)
field_size = int(layout.field)
train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size]
shard_stride = int(layout.opt_weight_shard_step)
shard_size = int(layout.opt_weight_shard_size)
train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size, shard_stride, shard_size]
except BaseException as e:
raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")
return train_map
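
For illustration, a hypothetical train_map entry produced by _convert_to_list after this change (parameter name and values invented); the last two fields are the newly added shard stride and shard size:

train_map = {
    "network.embedding.weight": [
        [4, 8],    # dev_mat
        [1, -1],   # tensor_map
        [0, 0],    # param_split_shape
        0,         # field_size
        8,         # opt_weight_shard_step (new)
        2,         # opt_weight_shard_size (new)
    ],
}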

View File

@ -131,6 +131,17 @@ def test_auto_parallel_momentum_5():
assert not param_dict["weight2"][5]
def test_auto_parallel_momentum_6():
# test partial use of the parallel optimizer with optimizer_weight_shard_size
# weight1 cannot be sharded and weight2 is repeated
context.set_auto_parallel_context(optimizer_weight_shard_size=2)
train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
param_dict = train_network.parameter_layout_dict
# validate opt_shard_group
assert param_dict["weight1"][5].startswith("2")
assert param_dict["weight2"][5].startswith("2")
def test_AdamWeightDecay():
""" test_AdamWeightDecay """
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)

View File

@ -59,6 +59,10 @@ def test_set_auto_parallel_context():
parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
assert parameter_broadcast_is_set
auto_parallel_context().set_optimizer_weight_shard_integrated_save(True)
integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()
assert integrated_save
with pytest.raises(ValueError):
context.set_auto_parallel_context(device_num=0)
@ -105,6 +109,7 @@ def test_reset_auto_parallel_context():
parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
stage = auto_parallel_context().get_pipeline_stages()
communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()
assert device_num == 1
assert global_rank == 0
@ -116,3 +121,4 @@ def test_reset_auto_parallel_context():
assert not parameter_broadcast_is_set
assert stage == 1
assert communi_parallel_mode == "all_group_parallel"
assert not integrated_save