From 4aae231a8a65b68afb3c9b6f6fc6072ea8427631 Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Sat, 10 Jul 2021 14:56:49 +0800
Subject: [PATCH] fix_pipeline_opt_shard

---
 .../graph_util/pipeline_split_utils.cc              |  5 ++++
 .../pipeline_transformer.cc                         |  5 +++-
 mindspore/ccsrc/frontend/parallel/step_parallel.cc  |  3 ++-
 mindspore/parallel/_utils.py                        | 10 ++++---
 mindspore/train/model.py                            |  2 +-
 .../official/nlp/pangu_alpha/src/utils.py           | 26 +++++++++++--------
 6 files changed, 33 insertions(+), 18 deletions(-)
 mode change 100755 => 100644 mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc

diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc
old mode 100755
new mode 100644
index 5f93d5ee8c9..e05802efd66
--- a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc
@@ -98,6 +98,11 @@ void InsertVirtualAssignAdd(const std::pair &node_user, const F
   if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) {
     return;
   }
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
+  if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && enable_parallel_optimizer) {
+    return;
+  }
   auto prim = GetCNodePrimitive(cnode);
   if (prim == nullptr) {
     MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAdd.";
diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
index d4c6f63347a..77c3e3f8a3c 100644
--- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
+++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
@@ -43,7 +43,7 @@ static std::unordered_map> parameter_color_map;
 // map
 static std::unordered_map send_tag_map;
 static std::unordered_map recv_tag_map;
-const std::set WHITE_LIST = {prim::kPrimCast, prim::kPrimTupleGetItem};
+const std::set WHITE_LIST = {prim::kPrimCast, prim::kPrimTupleGetItem, prim::kPrimMakeTuple};
 
 static bool IsInWhiteList(const CNodePtr &cnode) {
   for (auto &prim : WHITE_LIST) {
@@ -97,6 +97,9 @@ bool PipelineTransformer::NeedGrad(const CNodePtr &cnode) {
       if (load->input(1)->isa() && ParameterRequireGrad(load->input(1))) {
         return true;
       }
+      if (IsPrimitiveCNode(input, prim::kPrimCast)) {
+        return NeedGrad(input->cast());
+      }
     }
   }
   return false;
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index ccb16a5ebc5..de3854c976f 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1763,7 +1763,8 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
   for (auto &param_pair : param_sub_set) {
     auto cnode = param_pair.first->cast();
     MS_EXCEPTION_IF_NULL(cnode);
-    if (cnode->in_forward_flag() && !IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
+    if (cnode->in_forward_flag() && !IsPrimitiveCNode(cnode, prim::kPrimReceive) &&
+        !IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
       OperatorInfoPtr distribute_operator = cnode->user_data();
       if (distribute_operator == nullptr) {
         MS_LOG(DEBUG) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr";
diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py
index 23c3112dae2..76221fc5287 100644
--- a/mindspore/parallel/_utils.py
+++ b/mindspore/parallel/_utils.py
@@ -79,16 +79,18 @@ def _to_full_shapes(shapes, device_num):
     return new_shapes
 
 
-def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
+def _to_full_tensor(elem, global_device_num, global_rank, scaling_sens=None):
     """Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the
        data from host solution.
     """
     lst = []
+    device_num = global_device_num // _get_pipeline_stages()
+    stage_rank = global_rank % device_num
     if not isinstance(elem, (tuple, list)):
         elem = [elem]
-    if global_rank >= device_num:
+    if stage_rank >= device_num:
         raise ValueError("The global rank must be smaller than device number, the global rank is {}, "
-                         "the device num is {}".format(global_rank, device_num))
+                         "the device num is {}".format(stage_rank, device_num))
     for data in elem:
         if isinstance(data, np.ndarray):
@@ -106,7 +108,7 @@ def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
             else:
                 new_shape += (item,)
             new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
-            start = global_rank * batchsize_per_device
+            start = stage_rank * batchsize_per_device
             new_tensor_numpy[start: start + batchsize_per_device] = data.asnumpy()
             new_tensor = Tensor(new_tensor_numpy)
             lst.append(new_tensor)
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 65fa7f2b04d..d87ec722425 100644
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -889,7 +889,7 @@ class Model:
         >>> loss_scale_manager = FixedLossScaleManager()
         >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
         >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
-        >>> model.infer_train_layout(dataset)
+        >>> layout_dict = model.infer_train_layout(dataset)
         """
         self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)
diff --git a/model_zoo/official/nlp/pangu_alpha/src/utils.py b/model_zoo/official/nlp/pangu_alpha/src/utils.py
index 8449d4dcd8e..29a26557fc2 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/utils.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/utils.py
@@ -20,6 +20,7 @@ import os
 import time
 import numpy as np
 import mindspore.nn as nn
+from mindspore import context
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
@@ -137,8 +138,11 @@ class GlobalNorm(nn.Cell):
         self.hyper_map = C.HyperMap()
         self.is_pipeline = (config.stage_num > 1)
         if self.is_pipeline:
-            group_size = config.mp
-            group_list, group_name = _get_model_parallel_group(config.mp)
+            if context.get_auto_parallel_context("enable_parallel_optimizer"):
+                group_size = get_group_size() // config.stage_num
+            else:
+                group_size = config.mp
+            group_list, group_name = _get_model_parallel_group(group_size)
             create_group(group_name, group_list)
             self.allreduce = P.AllReduce(group=group_name)
             pipeline_group_list, pipeline_group_name = _get_pipeline_group()
@@ -146,18 +150,18 @@
             self.allreduce2 = P.AllReduce(group=pipeline_group_name)
         else:
             group_size = get_group_size()
-            if config.word_emb_dp:
-                self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
-                                              and "embedding_table" not in x.name for x in params)
-            else:
-                self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
-                                              and "position_embedding.embedding_table" not in x.name for x in params)
         self.allreduce_group_size = ()
-        for item in self.allreduce_filter:
-            if item:
+        for x in params:
+            if "projection.bias" not in x.name and "layernorm" not in x.name and "embedding_table" not in x.name:
                 self.allreduce_group_size = self.allreduce_group_size + (1.0,)
-            else:
+            elif "embedding_table" not in x.name:
                 self.allreduce_group_size = self.allreduce_group_size + (group_size * 1.0,)
+            else:
+                if not config.word_emb_dp and "position_embedding.embedding_table" not in x.name \
+                        and "top_query_embedding_table" not in x.name:
+                    self.allreduce_group_size = self.allreduce_group_size + (config.dp * 1.0,)
+                else:
+                    self.allreduce_group_size = self.allreduce_group_size + (group_size * 1.0,)
 
     def construct(self, grads):
         """Calculate global norm construct"""
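
Note (illustration, not part of the patch): the reworked _to_full_tensor above derives a per-stage device count and a stage-local rank before placing each host batch into the full batch. Below is a minimal standalone sketch of that placement logic under the patch's assumption that the global device count is divisible by the number of pipeline stages; the function name expand_to_full_batch and its arguments are hypothetical and exist only for this example.

    # Illustrative sketch of the stage-aware batch expansion added in _to_full_tensor.
    import numpy as np

    def expand_to_full_batch(data, global_device_num, global_rank, pipeline_stages):
        """Place this rank's host batch into its slice of the stage-local full batch."""
        device_num = global_device_num // pipeline_stages  # devices within one pipeline stage
        stage_rank = global_rank % device_num              # this rank's position inside its stage
        batch_per_device = data.shape[0]
        full_shape = (batch_per_device * device_num,) + data.shape[1:]
        full_batch = np.zeros(full_shape, dtype=data.dtype)
        start = stage_rank * batch_per_device
        full_batch[start: start + batch_per_device] = data
        return full_batch

    # Example: 8 devices, 2 pipeline stages -> 4 devices per stage; global rank 5 maps to stage rank 1.
    sample = np.ones((2, 4), dtype=np.float32)
    print(expand_to_full_batch(sample, 8, 5, 2).shape)  # (8, 4)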