!37558 Fix AutoParallel Loss Error

Merge pull request !37558 from huangxinjing/fix_loss_error_for_master
i-robot 2022-07-11 01:38:22 +00:00 committed by Gitee
commit d92fd418e3
5 changed files with 22 additions and 15 deletions

@@ -632,6 +632,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
MS_LOG(INFO) << "No kernel address";
return;
}
if (!AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index, true)) {
return;
}
auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);

@@ -2193,7 +2193,6 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_depth) {
// return -> make_tuple
if (current_prim->name() == MAKE_TUPLE) {
MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
return loss_node_info;
}
@@ -2667,7 +2666,6 @@ std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const An
std::vector<AnfNodePtr> root_forward_nodes;
auto loss_cnode = FindLossCNode(graph, 0).loss_node;
if (loss_cnode == nullptr) {
MS_LOG(WARNING) << "Can not find the loss cnode";
return root_forward_nodes;
}

@@ -130,15 +130,7 @@ bool NeedUpdate(const CNodePtr &getitem_cnode) {
}
bool WhetherUseDropoutV3(const CNodePtr &dropout, const abstract::ShapePtr &input_shape) {
// Only GPT with static shape use DropoutV3
auto fullname = dropout->fullname_with_scope();
if (fullname.find("PanguAlpha") != std::string::npos && !input_shape->IsDynamic()) {
auto shape = input_shape->shape();
int64_t shape_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
if (shape_size < kV3ShapeLimitSize) {
return true;
}
}
// v3 will cause memory error
return false;
}

@@ -16,6 +16,7 @@
Parallel Loss for the Parallel Training
This is an experimental interface that is subject to change or deletion.
"""
from mindspore.parallel import set_algo_parameters
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
@@ -24,6 +25,7 @@ from mindspore.nn import Cell
from mindspore.nn.loss.loss import _check_is_tensor
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_pipeline_stages
from .layers import _check_input_dtype, _check_input_shape
from .op_parallel_config import default_dpmp_config, OpParallelConfig
@@ -187,7 +189,12 @@ class CrossEntropyLoss(Cell):
", but got the type: {}.".format(type(parallel_config)))
dp = parallel_config.data_parallel
mp = parallel_config.model_parallel
self.add = P.Add().shard(((dp, mp), (1,)))
self.enable_force_redistribute = False
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL):
self.enable_force_redistribute = True
self.add = P.Add().shard(((dp, mp), ())).add_prim_attr("keep_alive", True)
self.add_label = P.Add().shard(((dp,), ())).add_prim_attr("keep_alive", True)
self._check_and_modify_sharding_context(dp)
self.sum2 = P.ReduceSum().shard(((1,),))
self.mul2 = P.Mul().shard(((1,), (1,)))
self.add2 = P.Add()
@@ -197,12 +204,20 @@ class CrossEntropyLoss(Cell):
self._softmax = _Softmax(parallel_config)
self._nllloss = _NLLLoss(parallel_config)
@staticmethod
def _check_and_modify_sharding_context(dp):
device_num = _get_device_num()
stages = _get_pipeline_stages()
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and dp * stages != device_num:
set_algo_parameters(fully_use_devices=False)
def construct(self, logits, label, input_mask):
self._check_input(logits, label, input_mask)
# The add is used to force the redistribution before stepping into the sub graphs when semi/auto parallel is enabled.
# After relu, the following case should be equal to add(logits, 0)
logits = self.add(logits, F.cast(self.relu(F.tuple_to_array((-1e-32,))), F.dtype(logits)))
if self.enable_force_redistribute:
logits = self.add(logits, 0)
label = self.add_label(label, 0)
softmax, one_hot_label = self._softmax(logits, label)
loss_reduce = self._nllloss(softmax, one_hot_label)
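
The loss.py hunk above is the core of the fix: an identity Add (sharded and tagged with the keep_alive primitive attribute) forces a redistribution of logits and label before they step into the sub graphs under semi/auto parallel, and set_algo_parameters(fully_use_devices=False) relaxes the strategy search when dp * pipeline_stages does not cover every device. Below is a minimal standalone sketch of the same pattern, using only the APIs imported in this hunk; the cell name ForceRedistribute and its dp/mp arguments are illustrative assumptions, not part of the commit.

from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.context import ParallelMode
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_pipeline_stages

class ForceRedistribute(Cell):
    """Hypothetical cell mirroring the redistribution pattern added to CrossEntropyLoss."""
    def __init__(self, dp, mp):
        super().__init__()
        self.enable_force_redistribute = False
        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL):
            self.enable_force_redistribute = True
        # Adding 0 is numerically a no-op, but the sharded Add with keep_alive
        # forces a redistribution before the tensors enter the following sub graphs.
        self.add = P.Add().shard(((dp, mp), ())).add_prim_attr("keep_alive", True)
        self.add_label = P.Add().shard(((dp,), ())).add_prim_attr("keep_alive", True)
        # Under AUTO_PARALLEL, if dp * pipeline_stages does not occupy every device,
        # drop the fully-use-devices constraint so strategy search can still succeed.
        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and dp * _get_pipeline_stages() != _get_device_num():
            set_algo_parameters(fully_use_devices=False)

    def construct(self, logits, label):
        if self.enable_force_redistribute:
            logits = self.add(logits, 0)
            label = self.add_label(label, 0)
        return logits, label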

@@ -2211,7 +2211,6 @@ def _get_lambda_func(total_layer=None):
# the pipeline stage must be in [0, parallel_config.pipeline_stage - 1]
pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
network.pipeline_stage = pp_id
logger.info(f"pipeline stage id is {pp_id}")
# Used for optimizer's fusion tag
dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
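
For reference, the pipeline-stage assignment in this helper is plain integer arithmetic: pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1). A worked example follows, assuming pp_dis is the per-stage layer count (roughly layers / pipeline_stage); the concrete numbers are illustrative only.

# Illustrative numbers: 24 layers over 4 pipeline stages, no offset.
layers, pipeline_stage, offset = 24, 4, 0
pp_dis = max(int(layers / pipeline_stage), 1)        # assumed per-stage layer count: 6
for layer_id in range(layers):
    pp_id = min((layer_id + offset) // pp_dis, pipeline_stage - 1)
    # layers 0-5 -> stage 0, 6-11 -> stage 1, 12-17 -> stage 2, 18-23 -> stage 3

# The optimizer fusion tag uses the same integer-division pattern:
gradient_aggregation_group = 4
dis = max(int(layers / gradient_aggregation_group), 1)   # fusion group size: 6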