fix_pipeline_opt_shard
This commit is contained in:
parent
2d11b5ea43
commit
bb94a9e811
5
mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc
Executable file → Normal file
5
mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc
Executable file → Normal file
|
@ -98,6 +98,11 @@ void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const F
|
|||
if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
||||
bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && enable_parallel_optimizer) {
|
||||
return;
|
||||
}
|
||||
auto prim = GetCNodePrimitive(cnode);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAdd.";
|
||||
|
|
|
@ -43,7 +43,7 @@ static std::unordered_map<AnfNodePtr, std::set<int64_t>> parameter_color_map;
|
|||
// map<rank, tag>
|
||||
static std::unordered_map<int64_t, int64_t> send_tag_map;
|
||||
static std::unordered_map<int64_t, int64_t> recv_tag_map;
|
||||
const std::set<PrimitivePtr> WHITE_LIST = {prim::kPrimCast, prim::kPrimTupleGetItem};
|
||||
const std::set<PrimitivePtr> WHITE_LIST = {prim::kPrimCast, prim::kPrimTupleGetItem, prim::kPrimMakeTuple};
|
||||
|
||||
static bool IsInWhiteList(const CNodePtr &cnode) {
|
||||
for (auto &prim : WHITE_LIST) {
|
||||
|
@ -97,6 +97,9 @@ bool PipelineTransformer::NeedGrad(const CNodePtr &cnode) {
|
|||
if (load->input(1)->isa<Parameter>() && ParameterRequireGrad(load->input(1))) {
|
||||
return true;
|
||||
}
|
||||
if (IsPrimitiveCNode(input, prim::kPrimCast)) {
|
||||
return NeedGrad(input->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
|
|
@ -1763,7 +1763,8 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
|
|||
for (auto ¶m_pair : param_sub_set) {
|
||||
auto cnode = param_pair.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->in_forward_flag() && !IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
|
||||
if (cnode->in_forward_flag() && !IsPrimitiveCNode(cnode, prim::kPrimReceive) &&
|
||||
!IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
|
||||
OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
|
||||
if (distribute_operator == nullptr) {
|
||||
MS_LOG(DEBUG) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr";
|
||||
|
|
|
@ -79,16 +79,18 @@ def _to_full_shapes(shapes, device_num):
|
|||
return new_shapes
|
||||
|
||||
|
||||
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
|
||||
def _to_full_tensor(elem, global_device_num, global_rank, scaling_sens=None):
|
||||
"""Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
|
||||
from host solution.
|
||||
"""
|
||||
lst = []
|
||||
device_num = global_device_num // _get_pipeline_stages()
|
||||
stage_rank = global_rank % device_num
|
||||
if not isinstance(elem, (tuple, list)):
|
||||
elem = [elem]
|
||||
if global_rank >= device_num:
|
||||
if stage_rank >= device_num:
|
||||
raise ValueError("The global rank must be smaller than device number, the global rank is {}, "
|
||||
"the device num is {}".format(global_rank, device_num))
|
||||
"the device num is {}".format(stage_rank, device_num))
|
||||
|
||||
for data in elem:
|
||||
if isinstance(data, np.ndarray):
|
||||
|
@ -106,7 +108,7 @@ def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
|
|||
else:
|
||||
new_shape += (item,)
|
||||
new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
|
||||
start = global_rank * batchsize_per_device
|
||||
start = stage_rank * batchsize_per_device
|
||||
new_tensor_numpy[start: start + batchsize_per_device] = data.asnumpy()
|
||||
new_tensor = Tensor(new_tensor_numpy)
|
||||
lst.append(new_tensor)
|
||||
|
|
|
@ -889,7 +889,7 @@ class Model:
|
|||
>>> loss_scale_manager = FixedLossScaleManager()
|
||||
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
|
||||
>>> model.infer_train_layout(dataset)
|
||||
>>> layout_dict = model.infer_train_layout(dataset)
|
||||
"""
|
||||
self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import os
|
|||
import time
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -137,8 +138,11 @@ class GlobalNorm(nn.Cell):
|
|||
self.hyper_map = C.HyperMap()
|
||||
self.is_pipeline = (config.stage_num > 1)
|
||||
if self.is_pipeline:
|
||||
group_size = config.mp
|
||||
group_list, group_name = _get_model_parallel_group(config.mp)
|
||||
if context.get_auto_parallel_context("enable_parallel_optimizer"):
|
||||
group_size = get_group_size() // config.stage_num
|
||||
else:
|
||||
group_size = config.mp
|
||||
group_list, group_name = _get_model_parallel_group(group_size)
|
||||
create_group(group_name, group_list)
|
||||
self.allreduce = P.AllReduce(group=group_name)
|
||||
pipeline_group_list, pipeline_group_name = _get_pipeline_group()
|
||||
|
@ -146,18 +150,18 @@ class GlobalNorm(nn.Cell):
|
|||
self.allreduce2 = P.AllReduce(group=pipeline_group_name)
|
||||
else:
|
||||
group_size = get_group_size()
|
||||
if config.word_emb_dp:
|
||||
self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
|
||||
and "embedding_table" not in x.name for x in params)
|
||||
else:
|
||||
self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
|
||||
and "position_embedding.embedding_table" not in x.name for x in params)
|
||||
self.allreduce_group_size = ()
|
||||
for item in self.allreduce_filter:
|
||||
if item:
|
||||
for x in params:
|
||||
if "projection.bias" not in x.name and "layernorm" not in x.name and "embedding_table" not in x.name:
|
||||
self.allreduce_group_size = self.allreduce_group_size + (1.0,)
|
||||
else:
|
||||
elif "embedding_table" not in x.name:
|
||||
self.allreduce_group_size = self.allreduce_group_size + (group_size * 1.0,)
|
||||
else:
|
||||
if not config.word_emb_dp and "position_embedding.embedding_table" not in x.name \
|
||||
and "top_query_embedding_table" not in x.name:
|
||||
self.allreduce_group_size = self.allreduce_group_size + (config.dp * 1.0,)
|
||||
else:
|
||||
self.allreduce_group_size = self.allreduce_group_size + (group_size * 1.0,)
|
||||
|
||||
def construct(self, grads):
|
||||
"""Calculate global norm construct"""
|
||||
|
|
Loading…
Reference in New Issue