!37558 Fix AutoParallel Loss Error
Merge pull request !37558 from huangxinjing/fix_loss_error_for_master
Commit d92fd418e3
@@ -632,6 +632,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
      MS_LOG(INFO) << "No kernel address";
      return;
    }
    if (!AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index, true)) {
      return;
    }
    auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
    auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
    auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
@@ -2193,7 +2193,6 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_depth) {
  // return -> make_tuple
  if (current_prim->name() == MAKE_TUPLE) {
    MS_LOG(WARNING) << "The loss has make_tuple, which is not supported";
    return loss_node_info;
  }
@@ -2667,7 +2666,6 @@ std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const An
  std::vector<AnfNodePtr> root_forward_nodes;
  auto loss_cnode = FindLossCNode(graph, 0).loss_node;
  if (loss_cnode == nullptr) {
    MS_LOG(WARNING) << "Can not find the loss cnode";
    return root_forward_nodes;
  }
@@ -130,15 +130,7 @@ bool NeedUpdate(const CNodePtr &getitem_cnode) {
}

bool WhetherUseDropoutV3(const CNodePtr &dropout, const abstract::ShapePtr &input_shape) {
  // Only GPT (PanguAlpha) with a static shape uses DropoutV3
  auto fullname = dropout->fullname_with_scope();
  if (fullname.find("PanguAlpha") != std::string::npos && !input_shape->IsDynamic()) {
    auto shape = input_shape->shape();
    int64_t shape_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
    if (shape_size < kV3ShapeLimitSize) {
      return true;
    }
  }
  // Otherwise DropoutV3 may cause a memory error
  return false;
}
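For readers less familiar with this C++ pass, the check above boils down to a product-of-dimensions threshold. A minimal Python re-expression is sketched below; the function name, parameters, and the limit are illustrative only, since kV3ShapeLimitSize is defined elsewhere in the C++ sources and is not shown in this diff.

from functools import reduce
import operator

def whether_use_dropout_v3(fullname, shape, is_dynamic, v3_shape_limit):
    """Mirror of WhetherUseDropoutV3: only static-shape PanguAlpha dropouts
    whose total element count stays under the limit switch to DropoutV3."""
    if "PanguAlpha" in fullname and not is_dynamic:
        shape_size = reduce(operator.mul, shape, 1)  # same as std::accumulate with multiplies
        if shape_size < v3_shape_limit:
            return True
    return False  # dynamic, non-PanguAlpha, or too-large shapes keep the regular dropout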
@@ -16,6 +16,7 @@
Parallel Loss for the Parallel Training
This is an experimental interface that is subject to change or deletion.
"""
from mindspore.parallel import set_algo_parameters
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
@@ -24,6 +25,7 @@ from mindspore.nn import Cell
from mindspore.nn.loss.loss import _check_is_tensor
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_pipeline_stages
from .layers import _check_input_dtype, _check_input_shape
from .op_parallel_config import default_dpmp_config, OpParallelConfig
@@ -187,7 +189,12 @@ class CrossEntropyLoss(Cell):
                                ", but got the type: {}.".format(type(parallel_config)))
        dp = parallel_config.data_parallel
        mp = parallel_config.model_parallel
        self.add = P.Add().shard(((dp, mp), (1,)))
        self.enable_force_redistribute = False
        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL):
            self.enable_force_redistribute = True
            self.add = P.Add().shard(((dp, mp), ())).add_prim_attr("keep_alive", True)
            self.add_label = P.Add().shard(((dp,), ())).add_prim_attr("keep_alive", True)
            self._check_and_modify_sharding_context(dp)
        self.sum2 = P.ReduceSum().shard(((1,),))
        self.mul2 = P.Mul().shard(((1,), (1,)))
        self.add2 = P.Add()
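As an aside on the strategies declared above, here is a hedged, standalone sketch of the same shard layout that runs without a distributed launch. The dp/mp values are hypothetical; the "keep_alive" attribute name is taken directly from the diff, and its purpose is inferred from the comments in construct below.

from mindspore.ops import operations as P

dp, mp = 2, 4  # hypothetical data-parallel / model-parallel degrees

# First input (the logits tensor) is split dp-ways on axis 0 and mp-ways on axis 1;
# the second input is a scalar, so its strategy is the empty tuple.
add = P.Add().shard(((dp, mp), ()))
# Labels are 1-D, so they are only split along the data-parallel axis.
add_label = P.Add().shard(((dp,), ()))

# "keep_alive" marks these adds so the add(x, 0) used to force a redistribution
# is not folded away under semi/auto parallel (purpose inferred from this PR).
add = add.add_prim_attr("keep_alive", True)
add_label = add_label.add_prim_attr("keep_alive", True)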
@@ -197,12 +204,20 @@
        self._softmax = _Softmax(parallel_config)
        self._nllloss = _NLLLoss(parallel_config)

    @staticmethod
    def _check_and_modify_sharding_context(dp):
        device_num = _get_device_num()
        stages = _get_pipeline_stages()
        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and dp * stages != device_num:
            set_algo_parameters(fully_use_devices=False)

    def construct(self, logits, label, input_mask):
        self._check_input(logits, label, input_mask)

        # The add forces a redistribution before stepping into the sub-graphs when semi/auto parallel is enabled.
        # After the relu, the following is equivalent to add(logits, 0).
        logits = self.add(logits, F.cast(self.relu(F.tuple_to_array((-1e-32,))), F.dtype(logits)))
        if self.enable_force_redistribute:
            logits = self.add(logits, 0)
            label = self.add_label(label, 0)
        softmax, one_hot_label = self._softmax(logits, label)
        loss_reduce = self._nllloss(softmax, one_hot_label)
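For orientation, a minimal usage sketch of this loss under an explicit OpParallelConfig follows. The import path assumes the public mindspore.nn.transformer package that this loss.py ships in, and the shapes, dtypes, and vocabulary size are made up; treat it as an illustration rather than the tested configuration from this PR.

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.nn.transformer import CrossEntropyLoss, OpParallelConfig

# Hypothetical split: 2-way data parallel, 4-way model parallel.
config = OpParallelConfig(data_parallel=2, model_parallel=4)
loss = CrossEntropyLoss(parallel_config=config)

batch_times_seq, vocab = 8, 32000
logits = Tensor(np.random.randn(batch_times_seq, vocab), ms.float32)
labels = Tensor(np.random.randint(0, vocab, (batch_times_seq,)), ms.int32)
input_mask = Tensor(np.ones((batch_times_seq,)), ms.float32)

# In stand-alone mode the shard strategies are ignored; under semi/auto parallel
# the add/add_label ops from this PR force the redistribution before the
# softmax and NLL sub-graphs.
output = loss(logits, labels, input_mask)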
@@ -2211,7 +2211,6 @@ def _get_lambda_func(total_layer=None):
        # the pipeline stage must be in [0, parallel_config.pipeline_stage - 1]
        pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
        network.pipeline_stage = pp_id
        logger.info(f"pipeline stage id is {pp_id}")

        # Used for optimizer's fusion tag
        dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
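To make the clamp in pp_id concrete, the small standalone example below reproduces the same arithmetic with hypothetical numbers (4 pipeline stages, 6 layers per stage, offset 0); pp_dis is computed earlier in the original function and is simply assumed here.

pipeline_stage = 4   # stands in for parallel_config.pipeline_stage
pp_dis = 6           # layers-per-stage spacing, assumed; computed upstream in the real code
offset = 0

for layer_id in (0, 5, 6, 17, 23):
    # Same expression as in the diff: floor-divide into a stage, then clamp to the last one.
    pp_id = min((layer_id + offset) // pp_dis, pipeline_stage - 1)
    print(f"layer {layer_id} -> pipeline stage {pp_id}")
# Layers 0-5 map to stage 0, 6-11 to stage 1, 12-17 to stage 2, 18-23 to stage 3;
# the min() clamp guards the case where a nonzero offset would push an id past the last stage.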