!30326 If a split index is greater than or equal to the number of gradients, log a warning.
Merge pull request !30326 from linqingke/r1.6
This commit is contained in:
commit
b97af2b6eb
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
"""grad reducer cell for distributed training"""
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.communication.management import GlobalComm, get_group_size
|
||||
from mindspore.common.tensor import RowTensor
|
||||
|
@ -29,6 +30,11 @@ reduce_opt = C.MultitypeFuncGraph("reduce_opt")
|
|||
|
||||
def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM_GROUP):
|
||||
""" initialize allreduce communication operators"""
|
||||
for indices in split_indices:
|
||||
if indices >= length:
|
||||
logger.warning(f"AllReduce's split index {indices} is greater than or equal to "
|
||||
f"the total gradient's number of {length}")
|
||||
|
||||
fusion_type = 2 ** 10
|
||||
split = 0
|
||||
fusion = ()
|
||||
|
@ -39,6 +45,7 @@ def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM
|
|||
if split_indices[split] <= i:
|
||||
fusion_type += 1
|
||||
split += 1
|
||||
|
||||
index = tuple(range(1, length + 1))
|
||||
op_list = ()
|
||||
for i in range(length):
|
||||
|
@ -70,6 +77,7 @@ def _init_allreduce_operators_by_parameters(parameters, split_indices, group, fu
|
|||
op.add_prim_attr('index', index)
|
||||
index += 1
|
||||
op_list = op_list + (op,)
|
||||
|
||||
if not param_fusion:
|
||||
if split_indices and fusion_type == 1:
|
||||
op_list = _init_allreduce_operators(len(parameters), split_indices, group)
|
||||
|
|
Loading…
Reference in New Issue