fix_pipeline_opt_shard

This commit is contained in:
yao_yf 2021-07-10 14:56:49 +08:00
parent 2d11b5ea43
commit bb94a9e811
6 changed files with 33 additions and 18 deletions

View File

@ -98,6 +98,11 @@ void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const F
if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) {
return;
}
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && enable_parallel_optimizer) {
return;
}
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAdd.";

View File

@ -43,7 +43,7 @@ static std::unordered_map<AnfNodePtr, std::set<int64_t>> parameter_color_map;
// map<rank, tag>
static std::unordered_map<int64_t, int64_t> send_tag_map;
static std::unordered_map<int64_t, int64_t> recv_tag_map;
const std::set<PrimitivePtr> WHITE_LIST = {prim::kPrimCast, prim::kPrimTupleGetItem};
const std::set<PrimitivePtr> WHITE_LIST = {prim::kPrimCast, prim::kPrimTupleGetItem, prim::kPrimMakeTuple};
static bool IsInWhiteList(const CNodePtr &cnode) {
for (auto &prim : WHITE_LIST) {
@ -97,6 +97,9 @@ bool PipelineTransformer::NeedGrad(const CNodePtr &cnode) {
if (load->input(1)->isa<Parameter>() && ParameterRequireGrad(load->input(1))) {
return true;
}
if (IsPrimitiveCNode(input, prim::kPrimCast)) {
return NeedGrad(input->cast<CNodePtr>());
}
}
}
return false;

View File

@ -1763,7 +1763,8 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
for (auto &param_pair : param_sub_set) {
auto cnode = param_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->in_forward_flag() && !IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
if (cnode->in_forward_flag() && !IsPrimitiveCNode(cnode, prim::kPrimReceive) &&
!IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(DEBUG) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr";

View File

@ -79,16 +79,18 @@ def _to_full_shapes(shapes, device_num):
return new_shapes
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
def _to_full_tensor(elem, global_device_num, global_rank, scaling_sens=None):
"""Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
from host solution.
"""
lst = []
device_num = global_device_num // _get_pipeline_stages()
stage_rank = global_rank % device_num
if not isinstance(elem, (tuple, list)):
elem = [elem]
if global_rank >= device_num:
if stage_rank >= device_num:
raise ValueError("The global rank must be smaller than device number, the global rank is {}, "
"the device num is {}".format(global_rank, device_num))
"the device num is {}".format(stage_rank, device_num))
for data in elem:
if isinstance(data, np.ndarray):
@ -106,7 +108,7 @@ def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
else:
new_shape += (item,)
new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
start = global_rank * batchsize_per_device
start = stage_rank * batchsize_per_device
new_tensor_numpy[start: start + batchsize_per_device] = data.asnumpy()
new_tensor = Tensor(new_tensor_numpy)
lst.append(new_tensor)

View File

@ -889,7 +889,7 @@ class Model:
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.infer_train_layout(dataset)
>>> layout_dict = model.infer_train_layout(dataset)
"""
self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)

View File

@ -20,6 +20,7 @@ import os
import time
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
@ -137,8 +138,11 @@ class GlobalNorm(nn.Cell):
self.hyper_map = C.HyperMap()
self.is_pipeline = (config.stage_num > 1)
if self.is_pipeline:
group_size = config.mp
group_list, group_name = _get_model_parallel_group(config.mp)
if context.get_auto_parallel_context("enable_parallel_optimizer"):
group_size = get_group_size() // config.stage_num
else:
group_size = config.mp
group_list, group_name = _get_model_parallel_group(group_size)
create_group(group_name, group_list)
self.allreduce = P.AllReduce(group=group_name)
pipeline_group_list, pipeline_group_name = _get_pipeline_group()
@ -146,18 +150,18 @@ class GlobalNorm(nn.Cell):
self.allreduce2 = P.AllReduce(group=pipeline_group_name)
else:
group_size = get_group_size()
if config.word_emb_dp:
self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
and "embedding_table" not in x.name for x in params)
else:
self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
and "position_embedding.embedding_table" not in x.name for x in params)
self.allreduce_group_size = ()
for item in self.allreduce_filter:
if item:
for x in params:
if "projection.bias" not in x.name and "layernorm" not in x.name and "embedding_table" not in x.name:
self.allreduce_group_size = self.allreduce_group_size + (1.0,)
else:
elif "embedding_table" not in x.name:
self.allreduce_group_size = self.allreduce_group_size + (group_size * 1.0,)
else:
if not config.word_emb_dp and "position_embedding.embedding_table" not in x.name \
and "top_query_embedding_table" not in x.name:
self.allreduce_group_size = self.allreduce_group_size + (config.dp * 1.0,)
else:
self.allreduce_group_size = self.allreduce_group_size + (group_size * 1.0,)
def construct(self, grads):
"""Calculate global norm construct"""