Warn if a split index is greater than or equal to the number of gradients.

linqingke 2022-02-21 14:40:33 +08:00
parent 6c301b6e1f
commit 009fb02ef4
1 changed file with 8 additions and 0 deletions


@@ -14,6 +14,7 @@
# ============================================================================
"""grad reducer cell for distributed training"""
from mindspore import context
from mindspore import log as logger
from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.common.tensor import RowTensor
@@ -29,6 +30,11 @@ reduce_opt = C.MultitypeFuncGraph("reduce_opt")
def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM_GROUP):
    """ initialize allreduce communication operators"""
    for indices in split_indices:
        if indices >= length:
            logger.warning(f"AllReduce's split index {indices} is greater than or equal to "
                           f"the total number of gradients {length}")
    fusion_type = 2 ** 10
    split = 0
    fusion = ()
@@ -39,6 +45,7 @@ def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM
        if split_indices[split] <= i:
            fusion_type += 1
            split += 1
    index = tuple(range(1, length + 1))
    op_list = ()
    for i in range(length):
@@ -70,6 +77,7 @@ def _init_allreduce_operators_by_parameters(parameters, split_indices, group, fu
        op.add_prim_attr('index', index)
        index += 1
        op_list = op_list + (op,)
    if not param_fusion:
        if split_indices and fusion_type == 1:
            op_list = _init_allreduce_operators(len(parameters), split_indices, group)
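
For readers without a MindSpore checkout, here is a minimal standalone sketch of the check this commit introduces. It mirrors the loop added to _init_allreduce_operators, but uses Python's stdlib logging in place of mindspore.log, and the helper name check_split_indices is hypothetical:

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("grad_reducer")

def check_split_indices(length, split_indices):
    """Warn for every split index that cannot fall inside the gradient list."""
    for index in split_indices:
        if index >= length:
            logger.warning(f"AllReduce's split index {index} is greater than or equal to "
                           f"the total number of gradients {length}")

# With 10 gradients, the boundary 12 can never take effect, so it is flagged:
check_split_indices(10, [3, 7, 12])

Judging from the fusion loop shown above (if split_indices[split] <= i), an out-of-range index is never reached while iterating over the gradients, so the remaining gradients simply stay in the current fusion group; the new warning surfaces the likely misconfiguration rather than raising an error.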