forked from mindspore-Ecosystem/mindspore
fix bugs of moe: only use a smaller dp in moe
This commit is contained in:
parent cd3cfc3320
commit 93da6bab46
@@ -23,7 +23,7 @@ from mindspore.ops import functional as F
 from mindspore.nn import Cell
 from mindspore.nn.loss.loss import _check_is_tensor
 from .layers import _check_input_dtype, _check_input_shape
-from .op_parallel_config import default_dpmp_config, OpParallelConfig, MoEParallelConfig
+from .op_parallel_config import default_dpmp_config, OpParallelConfig
 
 __all__ = ["CrossEntropyLoss"]
 
@@ -33,7 +33,7 @@ class CrossEntropyLoss(Cell):
     Calculate the cross entropy loss.
 
     Args:
-        parallel_config (OpParallelConfig, MoEParallelConfig): The parallel configure. Default `default_dpmp_config`,
+        parallel_config (OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
             an instance of `OpParallelConfig` with default args.
 
     Inputs:
@@ -65,16 +65,13 @@ class CrossEntropyLoss(Cell):
 
     def __init__(self, parallel_config=default_dpmp_config):
         super(CrossEntropyLoss, self).__init__()
-        if not isinstance(parallel_config, OpParallelConfig) and not isinstance(parallel_config, MoEParallelConfig):
+        if not isinstance(parallel_config, OpParallelConfig):
             raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig"
-                            " or MoEParallelConfig, but got the type: {}.".format(type(parallel_config)))
+                            ", but got the type: {}.".format(type(parallel_config)))
         dp = parallel_config.data_parallel
         mp = parallel_config.model_parallel
         self.sum = P.ReduceSum().shard(((dp, mp),))
-        if isinstance(parallel_config, MoEParallelConfig):
-            self.onehot = P.OneHot().shard(((dp, mp*parallel_config.expert_parallel), (), ()))
-        else:
-            self.onehot = P.OneHot().shard(((dp, mp), (), ()))
+        self.onehot = P.OneHot().shard(((dp, mp), (), ()))
         # on/off value for onehot, for smooth labeling, modify the off_value
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
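With this change CrossEntropyLoss no longer special-cases MoE: it accepts only an OpParallelConfig and always shards OneHot with the plain (dp, mp) strategy. A minimal usage sketch follows; the import path is an assumption (it differs across MindSpore versions) and the parallel sizes are made up:

    # Sketch only: import path and parallel sizes are assumptions, not part of this commit.
    from mindspore.parallel.nn import CrossEntropyLoss, OpParallelConfig

    # Only OpParallelConfig is accepted now; an MoEParallelConfig raises TypeError.
    loss = CrossEntropyLoss(parallel_config=OpParallelConfig(data_parallel=2, model_parallel=4))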
@@ -94,10 +94,12 @@ def _check_moe_config(moe_config=None, parallel_config=None):
         if device_num % parallel_config.expert_parallel != 0:
             raise ValueError(f"device_num: {device_num} should be a multiple of expert_parallel: "
                              f"{parallel_config.expert_parallel}.")
-        if parallel_config.data_parallel * parallel_config.model_parallel * parallel_config.expert_parallel > device_num:
-            raise ValueError(f"The product of the data parallel: {parallel_config.data_parallel}, "
-                             f"model parallel: {parallel_config.model_parallel}, and "
-                             f"expert parallel: {parallel_config.expert_parallel} "
+        if parallel_config.data_parallel % parallel_config.expert_parallel != 0:
+            raise ValueError(f"data parallel: {parallel_config.data_parallel} should be a multiple of "
+                             f"expert_parallel: {parallel_config.expert_parallel}.")
+        if parallel_config.data_parallel * parallel_config.model_parallel > device_num:
+            raise ValueError(f"The product of the data parallel: {parallel_config.data_parallel} and "
+                             f"model parallel: {parallel_config.model_parallel} "
                              f"should be less than device_num: {device_num}.")
 
 
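The configuration check changes shape: instead of bounding dp * mp * ep by device_num, the new code requires dp to be divisible by ep (so the FFN can carve its expert dimension out of the data-parallel dimension) and bounds only dp * mp by device_num. A standalone sketch of the new rule, with a hypothetical helper name and made-up numbers:

    def check_moe_parallel(data_parallel, model_parallel, expert_parallel, device_num):
        # Hypothetical mirror of the validation in _check_moe_config after this commit.
        if device_num % expert_parallel != 0:
            raise ValueError(f"device_num: {device_num} should be a multiple of "
                             f"expert_parallel: {expert_parallel}.")
        if data_parallel % expert_parallel != 0:
            raise ValueError(f"data parallel: {data_parallel} should be a multiple of "
                             f"expert_parallel: {expert_parallel}.")
        if data_parallel * model_parallel > device_num:
            raise ValueError(f"The product of the data parallel: {data_parallel} and "
                             f"model parallel: {model_parallel} should be less than "
                             f"device_num: {device_num}.")

    check_moe_parallel(data_parallel=4, model_parallel=2, expert_parallel=2, device_num=8)  # passes
    # check_moe_parallel(data_parallel=1, model_parallel=8, expert_parallel=2, device_num=8)  # raises, see the test change below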
@@ -406,12 +406,13 @@ class FeedForward(Cell):
                  parallel_config=default_dpmp_config):
         super(FeedForward, self).__init__()
         _check_config(parallel_config)
-        dp = parallel_config.data_parallel
         mp = parallel_config.model_parallel
+        if expert_num > 1:
+            ep = parallel_config.expert_parallel
+        else:
+            ep = 1
+        # ffn use less dp than other ops when use_moe, due to there are ops use dp and ep.
+        dp = int(parallel_config.data_parallel / ep)
         if ffn_hidden_size % mp != 0:
             raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the"
                              "num of "
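The FeedForward change is what the commit title refers to: when experts are enabled, the FFN shards with a smaller data-parallel factor, dp / ep, so that operators sharded over (dp, ep) and operators sharded over plain dp fit the same device mesh. A self-contained sketch of the effective-dp computation (function and variable names are hypothetical):

    def effective_ffn_dp(data_parallel, expert_parallel, expert_num):
        # The FFN uses less dp than other ops when MoE is enabled, since some ops shard with dp and ep.
        ep = expert_parallel if expert_num > 1 else 1
        return int(data_parallel / ep)

    assert effective_ffn_dp(data_parallel=4, expert_parallel=2, expert_num=4) == 2
    assert effective_ffn_dp(data_parallel=4, expert_parallel=2, expert_num=1) == 4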
@@ -153,7 +153,7 @@ class TransformerNet(nn.Cell):
                                    ffn_hidden_size=64,
                                    moe_config=moe_config,
                                    parallel_config=parallel_config)
-        self.loss = CrossEntropyLoss(parallel_config=parallel_config.moe_parallel_config)
+        self.loss = CrossEntropyLoss(parallel_config=parallel_config.dp_mp_config)
 
     def construct(self, x1, x2, x3, x4, x5, y, mask):
         predict, _, _ = self.network(x1, x2, x3, x4, x5)
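Because the loss now accepts only OpParallelConfig, the test network hands it the dp/mp slice of the transformer config (dp_mp_config) rather than moe_parallel_config. A hedged sketch of that wiring, with an assumed import path and made-up parallel sizes:

    # Sketch only: import path and sizes are assumptions, not part of this commit.
    from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss

    parallel_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=2, expert_parallel=2)
    # dp_mp_config exposes only data_parallel/model_parallel, which is what CrossEntropyLoss expects now.
    loss = CrossEntropyLoss(parallel_config=parallel_config.dp_mp_config)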
@@ -212,7 +212,8 @@ def test_moe_expert_parallel3():
     Expectation: Successful graph compilation.
     """
     local_p_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, expert_parallel=2)
-    moe_with_loss_plus_mutiparallel(local_p_config)
+    with pytest.raises(ValueError):
+        moe_with_loss_plus_mutiparallel(local_p_config)
 
 def test_moe_expert_parallel_exception():
     """