!37558 Fix AutoParallel Loss Error

Merge pull request !37558 from huangxinjing/fix_loss_error_for_master
i-robot 2022-07-11 01:38:22 +00:00 committed by Gitee
commit d92fd418e3
5 changed files with 22 additions and 15 deletions


@@ -632,6 +632,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
     MS_LOG(INFO) << "No kernel address";
     return;
   }
+  if (!AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index, true)) {
+    return;
+  }
   auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
   auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
   auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);


@@ -2193,7 +2193,6 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_depth) {
   // return -> make_tuple
   if (current_prim->name() == MAKE_TUPLE) {
-    MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
     return loss_node_info;
   }
@@ -2667,7 +2666,6 @@ std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const An
   std::vector<AnfNodePtr> root_forward_nodes;
   auto loss_cnode = FindLossCNode(graph, 0).loss_node;
   if (loss_cnode == nullptr) {
-    MS_LOG(WARNING) << "Can not find the loss cnode";
     return root_forward_nodes;
   }


@@ -130,15 +130,7 @@ bool NeedUpdate(const CNodePtr &getitem_cnode) {
 }
 bool WhetherUseDropoutV3(const CNodePtr &dropout, const abstract::ShapePtr &input_shape) {
-  // Only GPT with static shape use DropoutV3
-  auto fullname = dropout->fullname_with_scope();
-  if (fullname.find("PanguAlpha") != std::string::npos && !input_shape->IsDynamic()) {
-    auto shape = input_shape->shape();
-    int64_t shape_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
-    if (shape_size < kV3ShapeLimitSize) {
-      return true;
-    }
-  }
+  // v3 will cause memory error
   return false;
 }


@@ -16,6 +16,7 @@
 Parallel Loss for the Parallel Training
 This is an experimental interface that is subject to change or deletion.
 """
+from mindspore.parallel import set_algo_parameters
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
@@ -24,6 +25,7 @@ from mindspore.nn import Cell
 from mindspore.nn.loss.loss import _check_is_tensor
 from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
 from mindspore.context import ParallelMode
+from mindspore.parallel._utils import _get_device_num, _get_pipeline_stages
 from .layers import _check_input_dtype, _check_input_shape
 from .op_parallel_config import default_dpmp_config, OpParallelConfig
@@ -187,7 +189,12 @@ class CrossEntropyLoss(Cell):
                             ", but got the type: {}.".format(type(parallel_config)))
         dp = parallel_config.data_parallel
         mp = parallel_config.model_parallel
-        self.add = P.Add().shard(((dp, mp), (1,)))
+        self.enable_force_redistribute = False
+        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL):
+            self.enable_force_redistribute = True
+            self.add = P.Add().shard(((dp, mp), ())).add_prim_attr("keep_alive", True)
+            self.add_label = P.Add().shard(((dp,), ())).add_prim_attr("keep_alive", True)
+            self._check_and_modify_sharding_context(dp)
         self.sum2 = P.ReduceSum().shard(((1,),))
         self.mul2 = P.Mul().shard(((1,), (1,)))
         self.add2 = P.Add()
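
The replacement add is a deliberate no-op (add of 0) whose only purpose is to carry a shard strategy, so semi/auto parallel inserts a redistribution on logits (and now also on label) before the loss sub-graph; keep_alive keeps the otherwise removable add-zero from being optimized out. A minimal sketch of how the strategy tuple reads, with hypothetical dp/mp values (constructing the primitives works standalone, but the strategies only take effect inside a configured semi/auto-parallel run):

from mindspore.ops import operations as P

dp, mp = 2, 4  # hypothetical data-parallel / model-parallel degrees
# ((dp, mp), ()) means: split the first input (logits) dp-way on axis 0 and
# mp-way on axis 1; the second input (the scalar 0) carries an empty strategy.
# keep_alive marks the primitive so the no-op add survives graph optimization
# and the redistribution it forces is actually inserted.
add = P.Add().shard(((dp, mp), ())).add_prim_attr("keep_alive", True)
add_label = P.Add().shard(((dp,), ())).add_prim_attr("keep_alive", True)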
@@ -197,12 +204,20 @@ class CrossEntropyLoss(Cell):
         self._softmax = _Softmax(parallel_config)
         self._nllloss = _NLLLoss(parallel_config)

+    @staticmethod
+    def _check_and_modify_sharding_context(dp):
+        device_num = _get_device_num()
+        stages = _get_pipeline_stages()
+        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and dp * stages != device_num:
+            set_algo_parameters(fully_use_devices=False)
+
     def construct(self, logits, label, input_mask):
         self._check_input(logits, label, input_mask)
         # The add is used for forcing the redistribution before stepping in sub graphs, when semi/auto parallel enabled.
-        # After relu, the following case should be euqal to add(logits, 0)
-        logits = self.add(logits, F.cast(self.relu(F.tuple_to_array((-1e-32,))), F.dtype(logits)))
+        if self.enable_force_redistribute:
+            logits = self.add(logits, 0)
+            label = self.add_label(label, 0)
         softmax, one_hot_label = self._softmax(logits, label)
         loss_reduce = self._nllloss(softmax, one_hot_label)
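
The new _check_and_modify_sharding_context only relaxes the strategy search when the loss's data-parallel degree, multiplied by the pipeline stages, cannot occupy every device under AUTO_PARALLEL. A plain-Python sketch of that arithmetic with hypothetical device counts (no MindSpore needed to follow the numbers):

# Standalone sketch of the condition added above: when dp * stages != device_num,
# the auto-parallel search is told it may leave some devices unused
# (set_algo_parameters(fully_use_devices=False)).
def needs_relaxed_device_usage(data_parallel, pipeline_stages, device_num):
    """True when the loss's data-parallel sharding cannot cover every device."""
    return data_parallel * pipeline_stages != device_num

# 8 devices, 2 pipeline stages, data_parallel=2 -> 2 * 2 = 4 != 8 -> relax.
print(needs_relaxed_device_usage(data_parallel=2, pipeline_stages=2, device_num=8))  # True
# 8 devices, 1 stage, data_parallel=8 -> covers all devices -> no change.
print(needs_relaxed_device_usage(data_parallel=8, pipeline_stages=1, device_num=8))  # False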


@@ -2211,7 +2211,6 @@ def _get_lambda_func(total_layer=None):
         # the pipeline stage must be in [0, parallel_config.pipeline_stage - 1]
         pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
         network.pipeline_stage = pp_id
-        logger.info(f"pipeline stage id is {pp_id}")
         # Used for optimizer's fusion tag
         dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
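
For context, the surviving pp_id line maps each layer to a pipeline stage by integer division and clamps it to the last stage. A small worked example with hypothetical layer counts (the pp_dis definition is assumed here, not quoted from the file):

# Hypothetical values: 24 layers over 4 pipeline stages, no offset.
layers, pipeline_stage, offset = 24, 4, 0
pp_dis = max(layers // pipeline_stage, 1)  # assumed: layers handled per stage
for layer_id in (0, 5, 6, 23):
    pp_id = min((layer_id + offset) // pp_dis, pipeline_stage - 1)
    print(layer_id, "->", pp_id)  # 0->0, 5->0, 6->1, 23->3 (clamped to last stage)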