!37558 Fix AutoParallel Loss Error
Merge pull request !37558 from huangxinjing/fix_loss_error_for_master
Commit: d92fd418e3
@@ -632,6 +632,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
     MS_LOG(INFO) << "No kernel address";
     return;
   }
+  if (!AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index, true)) {
+    return;
+  }
   auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
   auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
   auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
@@ -2193,7 +2193,6 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_depth) {

   // return -> make_tuple
   if (current_prim->name() == MAKE_TUPLE) {
-    MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
     return loss_node_info;
   }

@@ -2667,7 +2666,6 @@ std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const An
   std::vector<AnfNodePtr> root_forward_nodes;
   auto loss_cnode = FindLossCNode(graph, 0).loss_node;
   if (loss_cnode == nullptr) {
-    MS_LOG(WARNING) << "Can not find the loss cnode";
     return root_forward_nodes;
   }

@@ -130,15 +130,7 @@ bool NeedUpdate(const CNodePtr &getitem_cnode) {
 }

 bool WhetherUseDropoutV3(const CNodePtr &dropout, const abstract::ShapePtr &input_shape) {
-  // Only GPT with static shape use DropoutV3
-  auto fullname = dropout->fullname_with_scope();
-  if (fullname.find("PanguAlpha") != std::string::npos && !input_shape->IsDynamic()) {
-    auto shape = input_shape->shape();
-    int64_t shape_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
-    if (shape_size < kV3ShapeLimitSize) {
-      return true;
-    }
-  }
+  // v3 will cause memory error
   return false;
 }

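For reference, the heuristic removed above, rendered as a small Python sketch (the limit constant, scope name, and shape below are placeholders, not values from the source): DropoutV3 was chosen only for static-shaped PanguAlpha graphs whose element count was below kV3ShapeLimitSize, and the whole branch is now dropped because V3 can cause memory errors.

from functools import reduce
from operator import mul

kV3ShapeLimitSize = 1 << 30          # placeholder; the real constant lives in the C++ file
fullname = "PanguAlpha/.../Dropout"  # hypothetical node scope name
shape = [8, 1024, 1024]              # hypothetical static input shape
is_dynamic = False

# Mirrors the removed condition: PanguAlpha scope, static shape, element count under the limit.
use_v3 = ("PanguAlpha" in fullname and not is_dynamic
          and reduce(mul, shape, 1) < kV3ShapeLimitSize)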
@@ -16,6 +16,7 @@
 Parallel Loss for the Parallel Training
 This is an experimental interface that is subject to change or deletion.
 """
+from mindspore.parallel import set_algo_parameters
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
@@ -24,6 +25,7 @@ from mindspore.nn import Cell
 from mindspore.nn.loss.loss import _check_is_tensor
 from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
 from mindspore.context import ParallelMode
+from mindspore.parallel._utils import _get_device_num, _get_pipeline_stages
 from .layers import _check_input_dtype, _check_input_shape
 from .op_parallel_config import default_dpmp_config, OpParallelConfig

@@ -187,7 +189,12 @@ class CrossEntropyLoss(Cell):
                             ", but got the type: {}.".format(type(parallel_config)))
         dp = parallel_config.data_parallel
         mp = parallel_config.model_parallel
-        self.add = P.Add().shard(((dp, mp), (1,)))
+        self.enable_force_redistribute = False
+        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL):
+            self.enable_force_redistribute = True
+            self.add = P.Add().shard(((dp, mp), ())).add_prim_attr("keep_alive", True)
+            self.add_label = P.Add().shard(((dp,), ())).add_prim_attr("keep_alive", True)
+            self._check_and_modify_sharding_context(dp)
         self.sum2 = P.ReduceSum().shard(((1,),))
         self.mul2 = P.Mul().shard(((1,), (1,)))
         self.add2 = P.Add()
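The two keep_alive Add operators introduced above are numeric no-ops; their only job is to pin a redistribution point in front of the loss sub-graphs (they are used in the construct changes in the next hunk). A minimal sketch of the pattern, with hypothetical dp and mp values and the keep_alive attribute assumed to keep the identity add from being optimized away:

from mindspore.ops import operations as P

dp, mp = 2, 4  # hypothetical data-parallel and model-parallel degrees
# Shard only the first input; the scalar second input stays unsharded.
force_add = P.Add().shard(((dp, mp), ())).add_prim_attr("keep_alive", True)
force_add_label = P.Add().shard(((dp,), ())).add_prim_attr("keep_alive", True)
# In construct: logits = force_add(logits, 0); label = force_add_label(label, 0)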
@@ -197,12 +204,20 @@ class CrossEntropyLoss(Cell):
         self._softmax = _Softmax(parallel_config)
         self._nllloss = _NLLLoss(parallel_config)

+    @staticmethod
+    def _check_and_modify_sharding_context(dp):
+        device_num = _get_device_num()
+        stages = _get_pipeline_stages()
+        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and dp * stages != device_num:
+            set_algo_parameters(fully_use_devices=False)
+
     def construct(self, logits, label, input_mask):
         self._check_input(logits, label, input_mask)

         # The add is used for forcing the redistribution before stepping in sub graphs, when semi/auto parallel enabled.
-        # After relu, the following case should be euqal to add(logits, 0)
-        logits = self.add(logits, F.cast(self.relu(F.tuple_to_array((-1e-32,))), F.dtype(logits)))
+        if self.enable_force_redistribute:
+            logits = self.add(logits, 0)
+            label = self.add_label(label, 0)
         softmax, one_hot_label = self._softmax(logits, label)
         loss_reduce = self._nllloss(softmax, one_hot_label)

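A worked example of the new _check_and_modify_sharding_context check, with hypothetical device numbers: when the data-parallel degree times the pipeline-stage count does not equal the device count, the (dp,)-style strategies used by this loss cannot cover every device, so the auto-parallel strategy search is allowed to leave some devices unused.

from mindspore.parallel import set_algo_parameters

device_num, stages, dp = 8, 1, 2   # hypothetical cluster: 8 devices, 1 stage, dp = 2
if dp * stages != device_num:      # 2 * 1 != 8
    set_algo_parameters(fully_use_devices=False)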
@@ -2211,7 +2211,6 @@ def _get_lambda_func(total_layer=None):
         # the pipeline stage must be in [0, parallel_config.pipeline_stage - 1]
         pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
         network.pipeline_stage = pp_id
-        logger.info(f"pipeline stage id is {pp_id}")

         # Used for optimizer's fusion tag
         dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
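For reference, a worked example of the pipeline-stage formula kept by this hunk, with hypothetical values (12 layers, 4 stages, pp_dis = 3, offset = 0): each block of three consecutive layers maps to one stage, and the min() clamp keeps any overflow in the last stage.

pipeline_stage = 4
pp_dis = 3
offset = 0
for layer_id in range(12):
    pp_id = min((layer_id + offset) // pp_dis, pipeline_stage - 1)
    # layer_id 0-2 -> stage 0, 3-5 -> stage 1, 6-8 -> stage 2, 9-11 -> stage 3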