!30326 If split indices is large than gradient's number, throw warnings.

Merge pull request !30326 from linqingke/r1.6
This commit is contained in:
i-robot 2022-02-23 07:21:00 +00:00 committed by Gitee
commit b97af2b6eb
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 8 additions and 0 deletions

View File

@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""grad reducer cell for distributed training""" """grad reducer cell for distributed training"""
from mindspore import context from mindspore import context
from mindspore import log as logger
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.common.tensor import RowTensor from mindspore.common.tensor import RowTensor
@ -29,6 +30,11 @@ reduce_opt = C.MultitypeFuncGraph("reduce_opt")
def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM_GROUP): def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM_GROUP):
""" initialize allreduce communication operators""" """ initialize allreduce communication operators"""
for indices in split_indices:
if indices >= length:
logger.warning(f"AllReduce's split index {indices} is greater than or equal to "
f"the total gradient's number of {length}")
fusion_type = 2 ** 10 fusion_type = 2 ** 10
split = 0 split = 0
fusion = () fusion = ()
@ -39,6 +45,7 @@ def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM
if split_indices[split] <= i: if split_indices[split] <= i:
fusion_type += 1 fusion_type += 1
split += 1 split += 1
index = tuple(range(1, length + 1)) index = tuple(range(1, length + 1))
op_list = () op_list = ()
for i in range(length): for i in range(length):
@ -70,6 +77,7 @@ def _init_allreduce_operators_by_parameters(parameters, split_indices, group, fu
op.add_prim_attr('index', index) op.add_prim_attr('index', index)
index += 1 index += 1
op_list = op_list + (op,) op_list = op_list + (op,)
if not param_fusion: if not param_fusion:
if split_indices and fusion_type == 1: if split_indices and fusion_type == 1:
op_list = _init_allreduce_operators(len(parameters), split_indices, group) op_list = _init_allreduce_operators(len(parameters), split_indices, group)