fix bugs of moe: use a smaller dp inside moe ops

b00518648 2022-03-26 16:16:36 +08:00
parent cd3cfc3320
commit 93da6bab46
4 changed files with 16 additions and 15 deletions
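In short: when MoE is enabled, the expert dimension (expert_parallel, ep) takes up part of the device mesh, so MoE ops can no longer use the full data_parallel (dp) for their data dimension. The sketch below only illustrates the relationship this commit enforces; the concrete sizes are hypothetical and not taken from the diff.

    # Hypothetical sizes, only to illustrate the dp/ep relationship enforced by this commit.
    device_num = 8
    data_parallel = 4       # dp used by non-MoE ops
    model_parallel = 2      # mp
    expert_parallel = 2     # ep, only meaningful when expert_num > 1

    # New preconditions checked in _check_moe_config (see the hunk below):
    assert data_parallel % expert_parallel == 0
    assert data_parallel * model_parallel <= device_num

    # Inside the MoE FeedForward (see the FeedForward hunk below), the expert
    # dimension takes ep out of dp, so the FFN runs with a smaller data-parallel factor:
    ffn_dp = data_parallel // expert_parallel   # 4 // 2 = 2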


@@ -23,7 +23,7 @@ from mindspore.ops import functional as F
 from mindspore.nn import Cell
 from mindspore.nn.loss.loss import _check_is_tensor
 from .layers import _check_input_dtype, _check_input_shape
-from .op_parallel_config import default_dpmp_config, OpParallelConfig, MoEParallelConfig
+from .op_parallel_config import default_dpmp_config, OpParallelConfig
 __all__ = ["CrossEntropyLoss"]
@@ -33,7 +33,7 @@ class CrossEntropyLoss(Cell):
     Calculate the cross entropy loss.
     Args:
-        parallel_config (OpParallelConfig, MoEParallelConfig): The parallel configure. Default `default_dpmp_config`,
+        parallel_config (OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
             an instance of `OpParallelConfig` with default args.
     Inputs:
@@ -65,16 +65,13 @@ class CrossEntropyLoss(Cell):
     def __init__(self, parallel_config=default_dpmp_config):
         super(CrossEntropyLoss, self).__init__()
-        if not isinstance(parallel_config, OpParallelConfig) and not isinstance(parallel_config, MoEParallelConfig):
+        if not isinstance(parallel_config, OpParallelConfig):
             raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig"
-                            " or MoEParallelConfig, but got the type: {}.".format(type(parallel_config)))
+                            ", but got the type: {}.".format(type(parallel_config)))
         dp = parallel_config.data_parallel
         mp = parallel_config.model_parallel
         self.sum = P.ReduceSum().shard(((dp, mp),))
-        if isinstance(parallel_config, MoEParallelConfig):
-            self.onehot = P.OneHot().shard(((dp, mp*parallel_config.expert_parallel), (), ()))
-        else:
-            self.onehot = P.OneHot().shard(((dp, mp), (), ()))
+        self.onehot = P.OneHot().shard(((dp, mp), (), ()))
         # on/off value for onehot, for smooth labeling, modify the off_value
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
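With the MoE branch removed, CrossEntropyLoss only accepts an OpParallelConfig, so MoE models now pass the plain dp/mp sub-config instead of the MoE-specific one (see the TransformerNet change below). A rough usage sketch; the import path and sizes are assumptions, while TransformerOpParallelConfig and dp_mp_config come from the test change in this commit:

    # Usage sketch only; the import path and sizes are assumed, not taken from the diff.
    from mindspore.nn.transformer import TransformerOpParallelConfig, CrossEntropyLoss

    p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
    # Pass only the dp/mp part of the config; MoEParallelConfig is no longer accepted here.
    loss = CrossEntropyLoss(parallel_config=p_config.dp_mp_config)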


@@ -94,10 +94,12 @@ def _check_moe_config(moe_config=None, parallel_config=None):
     if device_num % parallel_config.expert_parallel != 0:
         raise ValueError(f"device_num: {device_num} should be a multiple of expert_parallel: "
                          f"{parallel_config.expert_parallel}.")
-    if parallel_config.data_parallel * parallel_config.model_parallel * parallel_config.expert_parallel > device_num:
-        raise ValueError(f"The product of the data parallel: {parallel_config.data_parallel}, "
-                         f"model parallel: {parallel_config.model_parallel}, and "
-                         f"expert parallel: {parallel_config.expert_parallel} "
+    if parallel_config.data_parallel % parallel_config.expert_parallel != 0:
+        raise ValueError(f"data parallel: {parallel_config.data_parallel} should be a multiple of "
+                         f"expert_parallel: {parallel_config.expert_parallel}.")
+    if parallel_config.data_parallel * parallel_config.model_parallel > device_num:
+        raise ValueError(f"The product of the data parallel: {parallel_config.data_parallel} and "
+                         f"model parallel: {parallel_config.model_parallel} "
                          f"should be less than device_num: {device_num}.")


@@ -406,12 +406,13 @@ class FeedForward(Cell):
                  parallel_config=default_dpmp_config):
         super(FeedForward, self).__init__()
         _check_config(parallel_config)
-        dp = parallel_config.data_parallel
         mp = parallel_config.model_parallel
         if expert_num > 1:
             ep = parallel_config.expert_parallel
         else:
             ep = 1
+        # The FFN uses a smaller dp than other ops when use_moe is enabled, because some ops are sharded with both dp and ep.
+        dp = int(parallel_config.data_parallel / ep)
         if ffn_hidden_size % mp != 0:
             raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the"
                              "num of "


@@ -153,7 +153,7 @@ class TransformerNet(nn.Cell):
                                     ffn_hidden_size=64,
                                     moe_config=moe_config,
                                     parallel_config=parallel_config)
-        self.loss = CrossEntropyLoss(parallel_config=parallel_config.moe_parallel_config)
+        self.loss = CrossEntropyLoss(parallel_config=parallel_config.dp_mp_config)
     def construct(self, x1, x2, x3, x4, x5, y, mask):
         predict, _, _ = self.network(x1, x2, x3, x4, x5)
@@ -212,7 +212,8 @@ def test_moe_expert_parallel3():
     Expectation: Successful graph compilation.
     """
     local_p_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, expert_parallel=2)
-    moe_with_loss_plus_mutiparallel(local_p_config)
+    with pytest.raises(ValueError):
+        moe_with_loss_plus_mutiparallel(local_p_config)
 def test_moe_expert_parallel_exception():
     """